diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel index 6bb71ae98..269e10433 100644 --- a/streaming/BUILD.bazel +++ b/streaming/BUILD.bazel @@ -9,9 +9,6 @@ proto_library( srcs = ["src/protobuf/streaming.proto"], strip_import_prefix = "src", visibility = ["//visibility:public"], - deps = [ - "@com_google_protobuf//:any_proto", - ], ) proto_library( @@ -25,10 +22,7 @@ proto_library( srcs = ["src/protobuf/remote_call.proto"], strip_import_prefix = "src", visibility = ["//visibility:public"], - deps = [ - "streaming_proto", - "@com_google_protobuf//:any_proto", - ], + deps = ["streaming_proto"], ) cc_proto_library( @@ -76,10 +70,9 @@ cc_library( "src/util/*.h", ]), copts = COPTS, - includes = ["src"], + strip_include_prefix = "src", visibility = ["//visibility:public"], deps = [ - "ray_common.so", "ray_util.so", "@boost//:any", "@com_google_googletest//:gtest", @@ -150,62 +143,6 @@ cc_library( }), ) -cc_library( - name = "streaming_channel", - srcs = glob(["src/channel/*.cc"]), - hdrs = glob(["src/channel/*.h"]), - copts = COPTS, - visibility = ["//visibility:public"], - deps = [ - ":streaming_common", - ":streaming_message", - ":streaming_queue", - ":streaming_ring_buffer", - ":streaming_util", - ], -) - -cc_library( - name = "streaming_reliability", - srcs = glob(["src/reliability/*.cc"]), - hdrs = glob(["src/reliability/*.h"]), - copts = COPTS, - includes = ["src/"], - visibility = ["//visibility:public"], - deps = [ - ":streaming_channel", - ":streaming_message", - ":streaming_util", - ], -) - -cc_library( - name = "streaming_ring_buffer", - srcs = glob(["src/ring_buffer/*.cc"]), - hdrs = glob(["src/ring_buffer/*.h"]), - copts = COPTS, - includes = ["src/"], - visibility = ["//visibility:public"], - deps = [ - "core_worker_lib.so", - ":ray_common.so", - ":ray_util.so", - ":streaming_message", - "@boost//:circular_buffer", - "@boost//:thread", - ], -) - -cc_library( - name = "streaming_common", - srcs = glob(["src/common/*.cc"]), - hdrs = glob(["src/common/*.h"]), - copts = COPTS, - includes = ["src/"], - visibility = ["//visibility:public"], - deps = [], -) - cc_library( name = "streaming_lib", srcs = glob([ @@ -222,13 +159,11 @@ cc_library( deps = [ "ray_common.so", "ray_util.so", - ":streaming_channel", - ":streaming_common", ":streaming_config", ":streaming_message", ":streaming_queue", - ":streaming_reliability", ":streaming_util", + "@boost//:circular_buffer", ], ) @@ -349,7 +284,6 @@ genrule( mkdir -p "$$GENERATED_DIR" touch "$$GENERATED_DIR/__init__.py" sed -i -E 's/from streaming.src.protobuf/from ./' "$$GENERATED_DIR/remote_call_pb2.py" - sed -i -E 's/from protobuf/from ./' "$$GENERATED_DIR/remote_call_pb2.py" date > $@ """, local = 1, @@ -364,6 +298,7 @@ cc_binary( ]), copts = COPTS, linkshared = 1, + linkstatic = 1, visibility = ["//visibility:public"], deps = [ ":streaming_lib", diff --git a/streaming/java/BUILD.bazel b/streaming/java/BUILD.bazel index 8ee94c597..91de8130e 100644 --- a/streaming/java/BUILD.bazel +++ b/streaming/java/BUILD.bazel @@ -127,12 +127,10 @@ define_java_module( ":io_ray_ray_streaming-state", "//java:io_ray_ray_api", "//java:io_ray_ray_runtime", - "@maven//:commons_io_commons_io", "@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java", "@ray_streaming_maven//:com_google_code_findbugs_jsr305", "@ray_streaming_maven//:com_google_guava_guava", "@ray_streaming_maven//:com_google_protobuf_protobuf_java", - "@ray_streaming_maven//:commons_collections_commons_collections", "@ray_streaming_maven//:de_ruedigermoeller_fst", "@ray_streaming_maven//:org_aeonbits_owner_owner", "@ray_streaming_maven//:org_apache_commons_commons_lang3", diff --git a/streaming/java/checkstyle-suppressions.xml b/streaming/java/checkstyle-suppressions.xml index 1d86cdba0..cb07198ed 100644 --- a/streaming/java/checkstyle-suppressions.xml +++ b/streaming/java/checkstyle-suppressions.xml @@ -11,7 +11,4 @@ - - - diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl index b834f6a39..1fe083f99 100644 --- a/streaming/java/dependencies.bzl +++ b/streaming/java/dependencies.bzl @@ -25,7 +25,6 @@ def gen_streaming_java_deps(): "org.mockito:mockito-all:1.10.19", "org.powermock:powermock-module-testng:1.6.6", "org.powermock:powermock-api-mockito:1.6.6", - "commons-collections:commons-collections:3.2.1", ], repositories = [ "https://repo.spring.io/plugins-release/", diff --git a/streaming/java/generate_jni_header_files.sh b/streaming/java/generate_jni_header_files.sh deleted file mode 100755 index 5ce3cb7a2..000000000 --- a/streaming/java/generate_jni_header_files.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env bash - -set -e -set -x - -cd "$(dirname "$0")" - -bazel build all_streaming_tests_deploy.jar - -function generate_one() -{ - file=${1//./_}.h - javah -classpath ../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar "$1" - - # prepend licence first - cat < ../src/lib/java/"$file" -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -EOF - # then append the generated header file - cat "$file" >> ../src/lib/java/"$file" - rm -f "$file" -} - -generate_one io.ray.streaming.runtime.transfer.channel.ChannelId -generate_one io.ray.streaming.runtime.transfer.DataReader -generate_one io.ray.streaming.runtime.transfer.DataWriter -generate_one io.ray.streaming.runtime.transfer.TransferHandler - -rm -f io_ray_streaming_*.h diff --git a/streaming/java/pom.xml b/streaming/java/pom.xml index 003c7670a..3432e006e 100644 --- a/streaming/java/pom.xml +++ b/streaming/java/pom.xml @@ -65,6 +65,7 @@ 27.0.1-jre 2.57 + release diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java index 82791a622..f29915381 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java @@ -7,26 +7,4 @@ import java.io.Serializable; */ public interface Function extends Serializable { - /** - * This method will be called periodically by framework, you should return a a serializable - * object which represents function state, framework will help you to serialize this object, save - * it to storage, and load it back when in fail-over through. - * {@link Function#loadCheckpoint(Serializable)}. - * - * @return A serializable object which represents function state. - */ - default Serializable saveCheckpoint() { - return null; - } - - /** - * This method will be called by framework when a worker died and been restarted. - * We will pass the last object you returned in {@link Function#saveCheckpoint()} when - * doing checkpoint, you are responsible to load this object back to you function. - * - * @param checkpointObject the last object you returned in {@link Function#saveCheckpoint()} - */ - default void loadCheckpoint(Serializable checkpointObject) { - } - } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java index 96900d841..40135b34b 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java @@ -9,9 +9,9 @@ import io.ray.streaming.api.function.Function; */ public interface SourceFunction extends Function { - void init(int parallelism, int index); + void init(int parallel, int index); - void fetch(SourceContext ctx) throws Exception; + void run(SourceContext ctx) throws Exception; void close(); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java index ec63b7d7e..b14aa9a6c 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java @@ -1,6 +1,7 @@ package io.ray.streaming.api.function.internal; import io.ray.streaming.api.function.impl.SourceFunction; +import java.util.ArrayList; import java.util.Collection; /** @@ -11,25 +12,22 @@ import java.util.Collection; public class CollectionSourceFunction implements SourceFunction { private Collection values; - private boolean finished = false; public CollectionSourceFunction(Collection values) { this.values = values; } @Override - public void init(int totalParallel, int currentIndex) { + public void init(int parallel, int index) { } @Override - public void fetch(SourceContext ctx) throws Exception { - if (finished) { - return; - } + public void run(SourceContext ctx) throws Exception { for (T value : values) { ctx.collect(value); } - finished = true; + // empty collection + values = new ArrayList<>(); } @Override diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java index 4eb655689..c99ec9959 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java @@ -1,5 +1,6 @@ package io.ray.streaming.message; + import java.util.Objects; public class KeyRecord extends Record { diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java index d054b95a7..0bbb0d7a2 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java @@ -25,13 +25,4 @@ public interface Operator extends Serializable { ChainStrategy getChainStrategy(); - /** - * See {@link Function#saveCheckpoint()}. - */ - Serializable saveCheckpoint(); - - /** - * See {@link Function#loadCheckpoint(Serializable)}. - */ - void loadCheckpoint(Serializable checkpointObject); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java index 11f35f495..3cf9ab1d7 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java @@ -4,7 +4,7 @@ import io.ray.streaming.api.function.impl.SourceFunction.SourceContext; public interface SourceOperator extends Operator { - void fetch(); + void run(); SourceContext getSourceContext(); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java index fda6c5d0e..67bc77381 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java @@ -8,7 +8,6 @@ import io.ray.streaming.api.function.RichFunction; import io.ray.streaming.api.function.internal.Functions; import io.ray.streaming.message.KeyRecord; import io.ray.streaming.message.Record; -import java.io.Serializable; import java.util.List; public abstract class StreamOperator implements Operator { @@ -73,16 +72,6 @@ public abstract class StreamOperator implements Operator { } } - @Override - public Serializable saveCheckpoint() { - return function.saveCheckpoint(); - } - - @Override - public void loadCheckpoint(Serializable checkpointObject) { - function.loadCheckpoint(checkpointObject); - } - @Override public String getName() { return name; diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java index 3a4e32cbb..c7c9e7a18 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java @@ -13,7 +13,6 @@ import io.ray.streaming.operator.OperatorType; import io.ray.streaming.operator.SourceOperator; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.operator.TwoInputOperator; -import java.io.Serializable; import java.lang.reflect.Proxy; import java.util.Collections; import java.util.List; @@ -86,23 +85,6 @@ public abstract class ChainedOperator extends StreamOperator { return tailOperator; } - @Override - public Serializable saveCheckpoint() { - Object[] checkpoints = new Object[operators.size()]; - for (int i = 0; i < operators.size(); ++i) { - checkpoints[i] = operators.get(i).saveCheckpoint(); - } - return checkpoints; - } - - @Override - public void loadCheckpoint(Serializable checkpointObject) { - Serializable[] checkpoints = (Serializable[]) checkpointObject; - for (int i = 0; i < operators.size(); ++i) { - operators.get(i).loadCheckpoint(checkpoints[i]); - } - } - private RuntimeContext createRuntimeContext(RuntimeContext runtimeContext, int index) { return (RuntimeContext) Proxy.newProxyInstance(runtimeContext.getClass().getClassLoader(), new Class[] {RuntimeContext.class}, @@ -143,8 +125,8 @@ public abstract class ChainedOperator extends StreamOperator { } @Override - public void fetch() { - sourceOperator.fetch(); + public void run() { + sourceOperator.run(); } @Override diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java index 120701d88..495604c3a 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java @@ -29,9 +29,9 @@ public class SourceOperatorImpl extends StreamOperator> } @Override - public void fetch() { + public void run() { try { - this.function.fetch(this.sourceContext); + this.function.run(this.sourceContext); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/streaming/java/streaming-runtime/pom.xml b/streaming/java/streaming-runtime/pom.xml index 86a11ec8d..6f63fe147 100644 --- a/streaming/java/streaming-runtime/pom.xml +++ b/streaming/java/streaming-runtime/pom.xml @@ -1,8 +1,8 @@ - + + xmlns="http://maven.apache.org/POM/4.0.0" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> ray-streaming io.ray @@ -69,16 +69,6 @@ protobuf-java 3.8.0 - - commons-collections - commons-collections - 3.2.1 - - - commons-io - commons-io - 2.5 - de.ruedigermoeller fst diff --git a/streaming/java/streaming-runtime/pom_template.xml b/streaming/java/streaming-runtime/pom_template.xml index 7b8da2a9e..c1f890c12 100644 --- a/streaming/java/streaming-runtime/pom_template.xml +++ b/streaming/java/streaming-runtime/pom_template.xml @@ -1,8 +1,8 @@ - {auto_gen_header} + {auto_gen_header} + xmlns="http://maven.apache.org/POM/4.0.0" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> ray-streaming io.ray diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java index 3f1697149..ac7ede0a9 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java @@ -1,9 +1,7 @@ package io.ray.streaming.runtime.config; import com.google.common.base.Preconditions; -import io.ray.streaming.runtime.config.global.CheckpointConfig; import io.ray.streaming.runtime.config.global.CommonConfig; -import io.ray.streaming.runtime.config.global.ContextBackendConfig; import io.ray.streaming.runtime.config.global.TransferConfig; import java.io.Serializable; import java.lang.reflect.Method; @@ -21,19 +19,17 @@ import org.slf4j.LoggerFactory; public class StreamingGlobalConfig implements Serializable { private static final Logger LOG = LoggerFactory.getLogger(StreamingGlobalConfig.class); + public final CommonConfig commonConfig; public final TransferConfig transferConfig; + public final Map configMap; - public CheckpointConfig checkpointConfig; - public ContextBackendConfig contextBackendConfig; public StreamingGlobalConfig(final Map conf) { configMap = new HashMap<>(conf); commonConfig = ConfigFactory.create(CommonConfig.class, conf); transferConfig = ConfigFactory.create(TransferConfig.class, conf); - checkpointConfig = ConfigFactory.create(CheckpointConfig.class, conf); - contextBackendConfig = ConfigFactory.create(ContextBackendConfig.class, conf); globalConfig2Map(); } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java deleted file mode 100644 index b31bc7d8c..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java +++ /dev/null @@ -1,55 +0,0 @@ -package io.ray.streaming.runtime.config.global; - -import io.ray.streaming.runtime.config.Config; -import org.aeonbits.owner.Mutable; - -/** - * Configurations for checkpointing. - */ -public interface CheckpointConfig extends Config, Mutable { - - String CP_INTERVAL_SECS = "streaming.checkpoint.interval.secs"; - String CP_TIMEOUT_SECS = "streaming.checkpoint.timeout.secs"; - - String CP_PREFIX_KEY_MASTER = "streaming.checkpoint.prefix-key.job-master.context"; - String CP_PREFIX_KEY_WORKER = "streaming.checkpoint.prefix-key.job-worker.context"; - String CP_PREFIX_KEY_OPERATOR = "streaming.checkpoint.prefix-key.job-worker.operator"; - - /** - * Checkpoint time interval. JobMaster won't trigger 2 checkpoint in less than this time interval. - */ - @DefaultValue(value = "5") - @Key(value = CP_INTERVAL_SECS) - int cpIntervalSecs(); - - /** - * How long should JobMaster wait for checkpoint to finish. When this timeout is reached and - * JobMaster hasn't received all commits from workers, JobMaster will consider this checkpoint as - * failed and trigger another checkpoint. - */ - @DefaultValue(value = "120") - @Key(value = CP_TIMEOUT_SECS) - int cpTimeoutSecs(); - - /** - * This is used for saving JobMaster's context to storage, user usually don't need to change this. - */ - @DefaultValue(value = "job_master_runtime_context_") - @Key(value = CP_PREFIX_KEY_MASTER) - String jobMasterContextCpPrefixKey(); - - /** - * This is used for saving JobWorker's context to storage, user usually don't need to change this. - */ - @DefaultValue(value = "job_worker_context_") - @Key(value = CP_PREFIX_KEY_WORKER) - String jobWorkerContextCpPrefixKey(); - - /** - * This is used for saving user operator(in StreamTask)'s context to storage, user usually don't - * need to change this. - */ - @DefaultValue(value = "job_worker_op_") - @Key(value = CP_PREFIX_KEY_OPERATOR) - String jobWorkerOpCpPrefixKey(); -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java deleted file mode 100644 index 11d1d3371..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java +++ /dev/null @@ -1,17 +0,0 @@ -package io.ray.streaming.runtime.config.global; - -import org.aeonbits.owner.Config; - -public interface ContextBackendConfig extends Config { - - String STATE_BACKEND_TYPE = "streaming.context-backend.type"; - String FILE_STATE_ROOT_PATH = "streaming.context-backend.file-state.root"; - - @Config.DefaultValue(value = "memory") - @Key(value = STATE_BACKEND_TYPE) - String stateBackendType(); - - @Config.DefaultValue(value = "/tmp/ray_streaming_state") - @Key(value = FILE_STATE_ROOT_PATH) - String fileStateRootPath(); -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java index e6ea60d7a..7508dee2d 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java @@ -22,6 +22,13 @@ public interface TransferConfig extends Config { @Key(value = io.ray.streaming.util.Config.CHANNEL_SIZE) long channelSize(); + /** + * DataRead read timeout. + */ + @DefaultValue(value = "false") + @Key(value = io.ray.streaming.util.Config.IS_RECREATE) + boolean readerIsRecreate(); + /** * Return from DataReader.getBundle if only empty message read in this interval. */ diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java deleted file mode 100644 index 329e88c9a..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java +++ /dev/null @@ -1,22 +0,0 @@ -package io.ray.streaming.runtime.config.types; - -public enum ContextBackendType { - - /** - * Memory type - */ - MEMORY("memory", 0), - - /** - * Local File - */ - LOCAL_FILE("local_file", 1); - - private String name; - private int index; - - ContextBackendType(String name, int index) { - this.name = name; - this.index = index; - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java deleted file mode 100644 index b14cdcbb9..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java +++ /dev/null @@ -1,42 +0,0 @@ -package io.ray.streaming.runtime.context; - -import io.ray.streaming.runtime.master.JobMaster; -import io.ray.streaming.runtime.worker.JobWorker; - -/** - * This interface is used for storing context of {@link JobWorker} and {@link JobMaster}. - * The checkpoint returned by user function is also saved using this interface. - */ -public interface ContextBackend { - - /** - * check if key exists in state - * - * @return true if exists - */ - boolean exists(final String key) throws Exception; - - /** - * get content by key - * - * @param key key - * @return the StateBackend - */ - byte[] get(final String key) throws Exception; - - /** - * put content by key - * - * @param key key - * @param value content - */ - void put(final String key, final byte[] value) throws Exception; - - /** - * remove content by key - * - * @param key key - */ - void remove(final String key) throws Exception; - -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java deleted file mode 100644 index 2ca96b5de..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java +++ /dev/null @@ -1,27 +0,0 @@ -package io.ray.streaming.runtime.context; - -import io.ray.streaming.runtime.config.StreamingGlobalConfig; -import io.ray.streaming.runtime.config.types.ContextBackendType; -import io.ray.streaming.runtime.context.impl.AtomicFsBackend; -import io.ray.streaming.runtime.context.impl.MemoryContextBackend; - -public class ContextBackendFactory { - - public static ContextBackend getContextBackend(final StreamingGlobalConfig config) { - ContextBackend contextBackend; - ContextBackendType type = ContextBackendType.valueOf( - config.contextBackendConfig.stateBackendType().toUpperCase()); - - switch (type) { - case MEMORY: - contextBackend = new MemoryContextBackend(config.contextBackendConfig); - break; - case LOCAL_FILE: - contextBackend = new AtomicFsBackend(config.contextBackendConfig); - break; - default: - throw new RuntimeException("Unsupported context backend type."); - } - return contextBackend; - } -} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java deleted file mode 100644 index 85bceb13c..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java +++ /dev/null @@ -1,52 +0,0 @@ -package io.ray.streaming.runtime.context; - -import com.google.common.base.MoreObjects; -import io.ray.streaming.runtime.transfer.channel.OffsetInfo; -import java.io.Serializable; -import java.util.HashMap; -import java.util.Map; - -/** - * This data structure contains state information of a task. - */ -public class OperatorCheckpointInfo implements Serializable { - - /** - * key: channel ID, value: offset - */ - public Map inputPoints; - public Map outputPoints; - - /** - * a serializable checkpoint returned by processor - */ - public Serializable processorCheckpoint; - public long checkpointId; - - public OperatorCheckpointInfo() { - inputPoints = new HashMap<>(); - outputPoints = new HashMap<>(); - checkpointId = -1; - } - - public OperatorCheckpointInfo( - Map inputPoints, - Map outputPoints, - Serializable processorCheckpoint, - long checkpointId) { - this.inputPoints = inputPoints; - this.outputPoints = outputPoints; - this.checkpointId = checkpointId; - this.processorCheckpoint = processorCheckpoint; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("inputPoints", inputPoints) - .add("outputPoints", outputPoints) - .add("processorCheckpoint", processorCheckpoint) - .add("checkpointId", checkpointId) - .toString(); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java deleted file mode 100644 index 96288e281..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java +++ /dev/null @@ -1,48 +0,0 @@ -package io.ray.streaming.runtime.context.impl; - -import io.ray.streaming.runtime.config.global.ContextBackendConfig; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Achieves an atomic `put` method. - * known issue: if you crashed while write a key at first time, this code will not work. - */ -public class AtomicFsBackend extends LocalFileContextBackend { - - private static final Logger LOG = LoggerFactory.getLogger(AtomicFsBackend.class); - private static final String TMP_FLAG = "_tmp"; - - public AtomicFsBackend(final ContextBackendConfig config) { - super(config); - } - - @Override - public byte[] get(String key) throws Exception { - String tmpKey = key + TMP_FLAG; - if (super.exists(tmpKey) && !super.exists(key)) { - return super.get(tmpKey); - } - return super.get(key); - } - - @Override - public void put(String key, byte[] value) throws Exception { - String tmpKey = key + TMP_FLAG; - if (super.exists(tmpKey) && !super.exists(key)) { - super.rename(tmpKey, key); - } - super.put(tmpKey, value); - super.remove(key); - super.rename(tmpKey, key); - } - - @Override - public void remove(String key) { - String tmpKey = key + TMP_FLAG; - if (super.exists(tmpKey)) { - super.remove(tmpKey); - } - super.remove(key); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java deleted file mode 100644 index 41e180462..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java +++ /dev/null @@ -1,55 +0,0 @@ -package io.ray.streaming.runtime.context.impl; - -import io.ray.streaming.runtime.config.global.ContextBackendConfig; -import io.ray.streaming.runtime.context.ContextBackend; -import java.io.File; -import org.apache.commons.io.FileUtils; - -/** - * This context backend uses local file system and doesn't supports failover in cluster. - * But it supports failover in single node. - * This is a pure file system backend which doesn't support atomic writing, please don't use this - * class, instead, use {@link AtomicFsBackend} which extends this class. - */ -public class LocalFileContextBackend implements ContextBackend { - - private final String rootPath; - - - public LocalFileContextBackend(ContextBackendConfig config) { - rootPath = config.fileStateRootPath(); - } - - @Override - public boolean exists(String key) { - File file = new File(rootPath, key); - return file.exists(); - } - - @Override - public byte[] get(String key) throws Exception { - File file = new File(rootPath, key); - if (file.exists()) { - return FileUtils.readFileToByteArray(file); - } - return null; - } - - @Override - public void put(String key, byte[] value) throws Exception { - File file = new File(rootPath, key); - FileUtils.writeByteArrayToFile(file, value); - } - - @Override - public void remove(String key) { - File file = new File(rootPath, key); - FileUtils.deleteQuietly(file); - } - - protected void rename(String fromKey, String toKey) throws Exception { - File srcFile = new File(rootPath, fromKey); - File dstFile = new File(rootPath, toKey); - FileUtils.moveFile(srcFile, dstFile); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java deleted file mode 100644 index 0a3723e05..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java +++ /dev/null @@ -1,72 +0,0 @@ -package io.ray.streaming.runtime.context.impl; - -import io.ray.streaming.runtime.config.global.ContextBackendConfig; -import io.ray.streaming.runtime.context.ContextBackend; -import java.util.HashMap; -import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This context backend uses memory and doesn't supports failover. - * Data will be lost after worker died. - */ -public class MemoryContextBackend implements ContextBackend { - - private static final Logger LOG = LoggerFactory.getLogger(MemoryContextBackend.class); - - private final Map kvStore = new HashMap<>(); - - public MemoryContextBackend(ContextBackendConfig config) { - if (LOG.isInfoEnabled()) { - LOG.info("Start init memory state backend, config is {}.", config); - LOG.info("Finish init memory state backend."); - } - } - - @Override - public boolean exists(String key) { - return kvStore.containsKey(key); - } - - @Override - public byte[] get(final String key) { - if (LOG.isInfoEnabled()) { - LOG.info("Get value of key {} start.", key); - } - - byte[] readData = kvStore.get(key); - - if (LOG.isInfoEnabled()) { - LOG.info("Get value of key {} success.", key); - } - - return readData; - } - - @Override - public void put(final String key, final byte[] value) { - if (LOG.isInfoEnabled()) { - LOG.info("Put value of key {} start.", key); - } - - kvStore.put(key, value); - - if (LOG.isInfoEnabled()) { - LOG.info("Put value of key {} success.", key); - } - } - - @Override - public void remove(final String key) { - if (LOG.isInfoEnabled()) { - LOG.info("Remove value of key {} start.", key); - } - - kvStore.remove(key); - - if (LOG.isInfoEnabled()) { - LOG.info("Remove value of key {} success.", key); - } - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java index 90a8e2b18..1006165f3 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java @@ -9,8 +9,8 @@ import io.ray.streaming.message.Record; import io.ray.streaming.runtime.serialization.CrossLangSerializer; import io.ray.streaming.runtime.serialization.JavaSerializer; import io.ray.streaming.runtime.serialization.Serializer; +import io.ray.streaming.runtime.transfer.ChannelId; import io.ray.streaming.runtime.transfer.DataWriter; -import io.ray.streaming.runtime.transfer.channel.ChannelId; import java.nio.ByteBuffer; import java.util.Collection; import org.slf4j.Logger; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java index dcbf6b1ff..b9d19bf2d 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java @@ -1,29 +1,19 @@ package io.ray.streaming.runtime.core.graph.executiongraph; -import com.google.common.collect.Sets; import io.ray.api.BaseActorHandle; -import io.ray.api.id.ActorId; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Physical plan. */ public class ExecutionGraph implements Serializable { - private static final Logger LOG = LoggerFactory.getLogger(ExecutionGraph.class); - /** * Name of the job. */ @@ -39,27 +29,6 @@ public class ExecutionGraph implements Serializable { */ private Map executionJobVertexMap; - /** - * Data map for execution vertex. - * key: execution vertex id. - * value: execution vertex. - */ - private Map executionVertexMap; - - /** - * Data map for execution vertex. - * key: actor id. - * value: execution vertex. - */ - private Map actorIdExecutionVertexMap; - - - /** - * key: channel ID - * value: actors in both sides of this channel - */ - private Map> channelGroupedActors; - /** * The max parallelism of the whole graph. */ @@ -85,7 +54,7 @@ public class ExecutionGraph implements Serializable { } public List getExecutionJobVertexList() { - return new ArrayList<>(executionJobVertexMap.values()); + return new ArrayList(executionJobVertexMap.values()); } public Map getExecutionJobVertexMap() { @@ -96,58 +65,6 @@ public class ExecutionGraph implements Serializable { this.executionJobVertexMap = executionJobVertexMap; } - - /** - * generate relation mappings between actors, execution vertices and channels - * this method must be called after worker actor is set. - */ - public void generateActorMappings() { - LOG.info("Setup queue actors relation."); - - channelGroupedActors = new HashMap<>(); - actorIdExecutionVertexMap = new HashMap<>(); - - getAllExecutionVertices().forEach(curVertex -> { - - // current - actorIdExecutionVertexMap.put(curVertex.getActorId(), curVertex); - - // input - List inputEdges = curVertex.getInputEdges(); - inputEdges.forEach(inputEdge -> { - ExecutionVertex inputVertex = inputEdge.getSourceExecutionVertex(); - String channelId = curVertex.getChannelIdByPeerVertex(inputVertex); - addActorToChannelGroupedActors(channelGroupedActors, channelId, - inputVertex.getWorkerActor()); - }); - - // output - List outputEdges = curVertex.getOutputEdges(); - outputEdges.forEach(outputEdge -> { - ExecutionVertex outputVertex = outputEdge.getTargetExecutionVertex(); - String channelId = curVertex.getChannelIdByPeerVertex(outputVertex); - addActorToChannelGroupedActors(channelGroupedActors, channelId, - outputVertex.getWorkerActor()); - }); - }); - - LOG.debug("Channel grouped actors is: {}.", channelGroupedActors); - } - - private void addActorToChannelGroupedActors( - Map> channelGroupedActors, - String queueName, - BaseActorHandle actor) { - - Set actorSet = - channelGroupedActors.computeIfAbsent(queueName, k -> new HashSet<>()); - actorSet.add(actor); - } - - public void setExecutionVertexMap(Map executionVertexMap) { - this.executionVertexMap = executionVertexMap; - } - public Map getJobConfig() { return jobConfig; } @@ -197,73 +114,25 @@ public class ExecutionGraph implements Serializable { return executionJobVertexMap.values().stream() .map(ExecutionJobVertex::getExecutionVertices) .flatMap(Collection::stream) - .filter(ExecutionVertex::is2Add) + .filter(vertex -> vertex.is2Add()) .collect(Collectors.toList()); } /** * Get specified execution vertex from current execution graph by execution vertex id. * - * @param executionVertexId execution vertex id. + * @param vertexId execution vertex id. * @return the specified execution vertex. */ - public ExecutionVertex getExecutionVertexByExecutionVertexId(int executionVertexId) { - if (executionVertexMap.containsKey(executionVertexId)) { - return executionVertexMap.get(executionVertexId); - } - throw new RuntimeException("Vertex " + executionVertexId + " does not exist!"); - } - - - /** - * Get specified execution vertex from current execution graph by actor id. - * - * @param actorId the actor id of execution vertex. - * @return the specified execution vertex. - */ - public ExecutionVertex getExecutionVertexByActorId(ActorId actorId) { - return actorIdExecutionVertexMap.get(actorId); - } - - - /** - * Get specified actor by actor id. - * - * @param actorId the actor id of execution vertex. - * @return the specified actor handle. - */ - public Optional getActorById(ActorId actorId) { - return getAllActors().stream() - .filter(actor -> actor.getId().equals(actorId)) - .findFirst(); - } - - /** - * Get the peer actor in the other side of channelName of a given actor - * - * @param actor actor in this side - * @param channelName the channel name - * @return the peer actor in the other side - */ - public BaseActorHandle getPeerActor(BaseActorHandle actor, String channelName) { - Set set = getActorsByChannelId(channelName); - final BaseActorHandle[] res = new BaseActorHandle[1]; - set.forEach(anActor -> { - if (!anActor.equals(actor)) { - res[0] = anActor; + public ExecutionVertex getExecutionJobVertexByJobVertexId(int vertexId) { + for (ExecutionJobVertex executionJobVertex : executionJobVertexMap.values()) { + for (ExecutionVertex executionVertex : executionJobVertex.getExecutionVertices()) { + if (executionVertex.getExecutionVertexId() == vertexId) { + return executionVertex; + } } - }); - return res[0]; - } - - /** - * Get actors in both sides of a channelId - * - * @param channelId the channelId - * @return actors in both sides - */ - public Set getActorsByChannelId(String channelId) { - return channelGroupedActors.getOrDefault(channelId, Sets.newHashSet()); + } + throw new RuntimeException("Vertex " + vertexId + " does not exist!"); } /** @@ -333,27 +202,4 @@ public class ExecutionGraph implements Serializable { .collect(Collectors.toList()); } - public Set getActorName(Set actorIds) { - return getAllExecutionVertices().stream() - .filter(executionVertex -> actorIds.contains(executionVertex.getActorId())) - .map(ExecutionVertex::getActorName) - .collect(Collectors.toSet()); - } - - public String getActorName(ActorId actorId) { - Set set = Sets.newHashSet(); - set.add(actorId); - Set result = getActorName(set); - if (result.isEmpty()) { - return null; - } - return result.iterator().next(); - } - - public List getAllActorsId() { - return getAllActors().stream() - .map(BaseActorHandle::getId) - .collect(Collectors.toList()); - } - } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java index 6aa7936b2..6ab2fd911 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java @@ -3,12 +3,11 @@ package io.ray.streaming.runtime.core.graph.executiongraph; import com.google.common.base.MoreObjects; import io.ray.streaming.api.partition.Partition; import io.ray.streaming.jobgraph.JobEdge; -import java.io.Serializable; /** * An edge that connects two execution job vertices. */ -public class ExecutionJobEdge implements Serializable { +public class ExecutionJobEdge { /** * The source(upstream) execution job vertex. diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java index b617cc053..f0c87bd0f 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java @@ -8,7 +8,6 @@ import io.ray.streaming.jobgraph.JobVertex; import io.ray.streaming.jobgraph.VertexType; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.runtime.config.master.ResourceConfig; -import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -21,7 +20,7 @@ import org.aeonbits.owner.ConfigFactory; *

Execution job vertex is the physical form of {@link JobVertex} and * every execution job vertex is corresponding to a group of {@link ExecutionVertex}. */ -public class ExecutionJobVertex implements Serializable { +public class ExecutionJobVertex { /** * Unique id. Use {@link JobVertex}'s id directly. diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java index 5d6a2556c..0135b35ed 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java @@ -9,7 +9,6 @@ import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.runtime.config.master.ResourceConfig; import io.ray.streaming.runtime.core.resource.ContainerId; import io.ray.streaming.runtime.core.resource.ResourceType; -import io.ray.streaming.runtime.transfer.channel.ChannelId; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; @@ -61,8 +60,6 @@ public class ExecutionVertex implements Serializable { */ private ContainerId containerId; - private String pid; - /** * Worker actor handle. */ @@ -76,14 +73,6 @@ public class ExecutionVertex implements Serializable { private List inputEdges = new ArrayList<>(); private List outputEdges = new ArrayList<>(); - private transient List outputChannelIdList; - private transient List inputChannelIdList; - - private transient List outputActorList; - private transient List inputActorList; - private Map exeVertexChannelMap; - - public ExecutionVertex( int globalIndex, int index, @@ -103,7 +92,9 @@ public class ExecutionVertex implements Serializable { } private Map genWorkerConfig(Map jobConfig) { - return new HashMap<>(jobConfig); + Map workerConfig = new HashMap<>(); + workerConfig.putAll(jobConfig); + return workerConfig; } public int getExecutionVertexId() { @@ -170,14 +161,14 @@ public class ExecutionVertex implements Serializable { return workerActor; } - public void setWorkerActor(BaseActorHandle workerActor) { - this.workerActor = workerActor; - } - public ActorId getWorkerActorId() { return workerActor.getId(); } + public void setWorkerActor(BaseActorHandle workerActor) { + this.workerActor = workerActor; + } + public List getInputEdges() { return inputEdges; } @@ -208,14 +199,6 @@ public class ExecutionVertex implements Serializable { .collect(Collectors.toList()); } - public ActorId getActorId() { - return null == workerActor ? null : workerActor.getId(); - } - - public String getActorName() { - return String.valueOf(executionVertexId); - } - public Map getResource() { return resource; } @@ -236,89 +219,12 @@ public class ExecutionVertex implements Serializable { this.containerId = containerId; } - public String getPid() { - return pid; - } - - public void setPid(String pid) { - this.pid = pid; - } - public void setContainerIfNotExist(ContainerId containerId) { if (null == this.containerId) { this.containerId = containerId; } } - /*---------channel-actor relations---------*/ - public List getOutputChannelIdList() { - if (outputChannelIdList == null) { - generateActorChannelInfo(); - } - return outputChannelIdList; - } - - public List getOutputActorList() { - if (outputActorList == null) { - generateActorChannelInfo(); - } - return outputActorList; - } - - public List getInputChannelIdList() { - if (inputChannelIdList == null) { - generateActorChannelInfo(); - } - return inputChannelIdList; - } - - public List getInputActorList() { - if (inputActorList == null) { - generateActorChannelInfo(); - } - return inputActorList; - } - - - public String getChannelIdByPeerVertex(ExecutionVertex peerVertex) { - if (exeVertexChannelMap == null) { - generateActorChannelInfo(); - } - return exeVertexChannelMap.get(peerVertex.getExecutionVertexId()); - } - - - private void generateActorChannelInfo() { - inputChannelIdList = new ArrayList<>(); - inputActorList = new ArrayList<>(); - outputChannelIdList = new ArrayList<>(); - outputActorList = new ArrayList<>(); - exeVertexChannelMap = new HashMap<>(); - - List inputEdges = getInputEdges(); - for (ExecutionEdge edge : inputEdges) { - String channelId = ChannelId.genIdStr( - edge.getSourceExecutionVertex().getExecutionVertexId(), - getExecutionVertexId(), - getBuildTime()); - inputChannelIdList.add(channelId); - inputActorList.add(edge.getSourceExecutionVertex().getWorkerActor()); - exeVertexChannelMap.put(edge.getSourceExecutionVertex().getExecutionVertexId(), channelId); - } - - List outputEdges = getOutputEdges(); - for (ExecutionEdge edge : outputEdges) { - String channelId = ChannelId.genIdStr( - getExecutionVertexId(), - edge.getTargetExecutionVertex().getExecutionVertexId(), - getBuildTime()); - outputChannelIdList.add(channelId); - outputActorList.add(edge.getTargetExecutionVertex().getWorkerActor()); - exeVertexChannelMap.put(edge.getTargetExecutionVertex().getExecutionVertexId(), channelId); - } - } - - private Map generateResources(ResourceConfig resourceConfig) { Map resourceMap = new HashMap<>(); if (resourceConfig.isTaskCpuResourceLimit()) { diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java index d189c42c1..500721d3d 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java @@ -15,7 +15,7 @@ public class ProcessBuilder { public static StreamProcessor buildProcessor(StreamOperator streamOperator) { OperatorType type = streamOperator.getOpType(); LOGGER.info("Building StreamProcessor, operator type = {}, operator = {}.", type, - streamOperator.getClass().getSimpleName()); + streamOperator.getClass().getSimpleName().toString()); switch (type) { case SOURCE: return new SourceProcessor<>((SourceOperator) streamOperator); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java index 54fe76cd8..3b128376c 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java @@ -2,7 +2,6 @@ package io.ray.streaming.runtime.core.processor; import io.ray.streaming.api.collector.Collector; import io.ray.streaming.api.context.RuntimeContext; -import io.ray.streaming.api.function.Function; import java.io.Serializable; import java.util.List; @@ -12,15 +11,5 @@ public interface Processor extends Serializable { void process(T t); - /** - * See {@link Function#saveCheckpoint()}. - */ - Serializable saveCheckpoint(); - - /** - * See {@link Function#loadCheckpoint(Serializable)}. - */ - void loadCheckpoint(Serializable checkpointObject); - void close(); } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java index 1cc721a2a..020f39d16 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java @@ -19,8 +19,8 @@ public class SourceProcessor extends StreamProcessor implements Processo LOGGER.info("opened {}", this); } - @Override - public Serializable saveCheckpoint() { - return operator.saveCheckpoint(); - } - - @Override - public void loadCheckpoint(Serializable checkpointObject) { - operator.loadCheckpoint(checkpointObject); - } - @Override public String toString() { return this.getClass().getSimpleName(); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java index 6115e4d50..60d3a0843 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java @@ -1,36 +1,18 @@ package io.ray.streaming.runtime.master; import com.google.common.base.Preconditions; -import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.ActorHandle; -import io.ray.api.BaseActorHandle; -import io.ray.api.Ray; -import io.ray.api.id.ActorId; import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.runtime.config.StreamingConfig; import io.ray.streaming.runtime.config.StreamingMasterConfig; -import io.ray.streaming.runtime.context.ContextBackend; -import io.ray.streaming.runtime.context.ContextBackendFactory; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; -import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; -import io.ray.streaming.runtime.core.resource.Container; -import io.ray.streaming.runtime.generated.RemoteCall; -import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; -import io.ray.streaming.runtime.master.coordinator.CheckpointCoordinator; -import io.ray.streaming.runtime.master.coordinator.FailoverCoordinator; -import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; -import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; import io.ray.streaming.runtime.master.graphmanager.GraphManager; import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; import io.ray.streaming.runtime.master.resourcemanager.ResourceManager; import io.ray.streaming.runtime.master.resourcemanager.ResourceManagerImpl; import io.ray.streaming.runtime.master.scheduler.JobSchedulerImpl; -import io.ray.streaming.runtime.util.CheckpointStateUtil; -import io.ray.streaming.runtime.util.ResourceUtil; -import io.ray.streaming.runtime.util.Serializer; import io.ray.streaming.runtime.worker.JobWorker; import java.util.Map; -import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,68 +24,33 @@ public class JobMaster { private static final Logger LOG = LoggerFactory.getLogger(JobMaster.class); - private JobMasterRuntimeContext runtimeContext; + private JobRuntimeContext runtimeContext; private ResourceManager resourceManager; private JobSchedulerImpl scheduler; private GraphManager graphManager; private StreamingMasterConfig conf; - private ContextBackend contextBackend; - private ActorHandle jobMasterActor; - // coordinators - private CheckpointCoordinator checkpointCoordinator; - private FailoverCoordinator failoverCoordinator; - public JobMaster(Map confMap) { LOG.info("Creating job master with conf: {}.", confMap); StreamingConfig streamingConfig = new StreamingConfig(confMap); this.conf = streamingConfig.masterConfig; - this.contextBackend = ContextBackendFactory.getContextBackend(this.conf); // init runtime context - runtimeContext = new JobMasterRuntimeContext(streamingConfig); - - // load checkpoint if is recover - if (Ray.getRuntimeContext().wasCurrentActorRestarted()) { - loadMasterCheckpoint(); - } + runtimeContext = new JobRuntimeContext(streamingConfig); LOG.info("Finished creating job master."); } - public static String getJobMasterRuntimeContextKey(StreamingMasterConfig conf) { - return conf.checkpointConfig.jobMasterContextCpPrefixKey() + conf.commonConfig.jobName(); - } - - private void loadMasterCheckpoint() { - LOG.info("Start to load JobMaster's checkpoint."); - // recover runtime context - byte[] bytes = - CheckpointStateUtil.get(contextBackend, getJobMasterRuntimeContextKey(getConf())); - if (bytes == null) { - LOG.warn("JobMaster got empty checkpoint from state backend. Skip loading checkpoint."); - // cp 0 was automatically saved when job started, see StreamTask. - runtimeContext.checkpointIds.add(0L); - return; - } - - this.runtimeContext = Serializer.decode(bytes); - - // FO case, triggered by ray, we need to register context when loading checkpoint - LOG.info("JobMaster recover runtime context[{}] from state backend.", runtimeContext); - init(true); - } - /** * Init JobMaster. To initiate or recover other components(like metrics and extra coordinators). * * @return init result */ - public Boolean init(boolean isRecover) { - LOG.info("Initializing job master, isRecover={}.", isRecover); + public Boolean init() { + LOG.info("Initializing job master."); if (this.runtimeContext.getExecutionGraph() == null) { LOG.error("Init job master failed. Job graphs is null."); @@ -113,14 +60,6 @@ public class JobMaster { ExecutionGraph executionGraph = graphManager.getExecutionGraph(); Preconditions.checkArgument(executionGraph != null, "no execution graph"); - // init coordinators - checkpointCoordinator = new CheckpointCoordinator(this); - checkpointCoordinator.start(); - failoverCoordinator = new FailoverCoordinator(this, isRecover); - failoverCoordinator.start(); - - saveContext(); - LOG.info("Finished initializing job master."); return true; } @@ -162,86 +101,11 @@ public class JobMaster { return true; } - public synchronized void saveContext() { - if (runtimeContext != null && getConf() != null) { - LOG.debug("Save JobMaster context."); - - byte[] contextBytes = Serializer.encode(runtimeContext); - CheckpointStateUtil - .put(contextBackend, getJobMasterRuntimeContextKey(getConf()), contextBytes); - } - } - - public byte[] reportJobWorkerCommit(byte[] reportBytes) { - Boolean ret = false; - RemoteCall.BaseWorkerCmd reportPb; - try { - reportPb = RemoteCall.BaseWorkerCmd.parseFrom(reportBytes); - ActorId actorId = ActorId.fromBytes(reportPb.getActorId().toByteArray()); - long remoteCallCost = System.currentTimeMillis() - reportPb.getTimestamp(); - LOG.info("Vertex {}, request job worker commit cost {}ms, actorId={}.", - getExecutionVertex(actorId), remoteCallCost, actorId); - RemoteCall.WorkerCommitReport commit = - reportPb.getDetail().unpack(RemoteCall.WorkerCommitReport.class); - WorkerCommitReport report = new WorkerCommitReport(actorId, commit.getCommitCheckpointId()); - ret = checkpointCoordinator.reportJobWorkerCommit(report); - } catch (InvalidProtocolBufferException e) { - LOG.error("Parse job worker commit has exception.", e); - } - return RemoteCall.BoolResult.newBuilder().setBoolRes(ret).build().toByteArray(); - } - - public byte[] requestJobWorkerRollback(byte[] requestBytes) { - Boolean ret = false; - RemoteCall.BaseWorkerCmd requestPb; - try { - requestPb = RemoteCall.BaseWorkerCmd.parseFrom(requestBytes); - ActorId actorId = ActorId.fromBytes(requestPb.getActorId().toByteArray()); - long remoteCallCost = System.currentTimeMillis() - requestPb.getTimestamp(); - ExecutionGraph executionGraph = graphManager.getExecutionGraph(); - Optional rayActor = executionGraph.getActorById(actorId); - if (!rayActor.isPresent()) { - LOG.warn("Skip this invalid rollback, actor id {} is not found.", actorId); - return RemoteCall.BoolResult.newBuilder().setBoolRes(false).build().toByteArray(); - } - ExecutionVertex exeVertex = getExecutionVertex(actorId); - LOG.info("Vertex {}, request job worker rollback cost {}ms, actorId={}.", - exeVertex, remoteCallCost, actorId); - RemoteCall.WorkerRollbackRequest rollbackPb - = RemoteCall.WorkerRollbackRequest.parseFrom(requestPb.getDetail().getValue()); - exeVertex.setPid(rollbackPb.getWorkerPid()); - // To find old container where slot is located in. - String hostname = ""; - Optional container = ResourceUtil.getContainerById( - resourceManager.getRegisteredContainers(), - exeVertex.getContainerId() - ); - if (container.isPresent()) { - hostname = container.get().getHostname(); - } - WorkerRollbackRequest request = new WorkerRollbackRequest( - actorId, rollbackPb.getExceptionMsg(), hostname, exeVertex.getPid() - ); - - ret = failoverCoordinator.requestJobWorkerRollback(request); - LOG.info("Vertex {} request rollback, exception msg : {}.", - exeVertex, rollbackPb.getExceptionMsg()); - - } catch (Throwable e) { - LOG.error("Parse job worker rollback has exception.", e); - } - return RemoteCall.BoolResult.newBuilder().setBoolRes(ret).build().toByteArray(); - } - - private ExecutionVertex getExecutionVertex(ActorId id) { - return graphManager.getExecutionGraph().getExecutionVertexByActorId(id); - } - public ActorHandle getJobMasterActor() { return jobMasterActor; } - public JobMasterRuntimeContext getRuntimeContext() { + public JobRuntimeContext getRuntimeContext() { return runtimeContext; } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java deleted file mode 100644 index c9e6e8f57..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java +++ /dev/null @@ -1,81 +0,0 @@ -package io.ray.streaming.runtime.master.context; - -import com.google.common.base.MoreObjects; -import com.google.common.collect.Sets; -import io.ray.streaming.jobgraph.JobGraph; -import io.ray.streaming.runtime.config.StreamingConfig; -import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; -import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; - -/** - * Runtime context for job master, which will be stored in backend when saving checkpoint. - * - *

Including: graph, resource, checkpoint info, etc. - */ -public class JobMasterRuntimeContext implements Serializable { - - /*--------------Checkpoint----------------*/ - public volatile List checkpointIds = new ArrayList<>(); - public volatile long lastCheckpointId = 0; - public volatile long lastCpTimestamp = 0; - public volatile BlockingQueue cpCmds = new LinkedBlockingQueue<>(); - /*--------------Failover----------------*/ - public volatile BlockingQueue foCmds = new ArrayBlockingQueue<>(8192); - public volatile Set unfinishedFoCmds = Sets.newConcurrentHashSet(); - private StreamingConfig conf; - private JobGraph jobGraph; - private volatile ExecutionGraph executionGraph; - - public JobMasterRuntimeContext(StreamingConfig conf) { - this.conf = conf; - } - - public String getJobName() { - return conf.masterConfig.commonConfig.jobName(); - } - - public StreamingConfig getConf() { - return conf; - } - - public JobGraph getJobGraph() { - return jobGraph; - } - - public void setJobGraph(JobGraph jobGraph) { - this.jobGraph = jobGraph; - } - - public ExecutionGraph getExecutionGraph() { - return executionGraph; - } - - public void setExecutionGraph(ExecutionGraph executionGraph) { - this.executionGraph = executionGraph; - } - - public Long getLastValidCheckpointId() { - if (checkpointIds.isEmpty()) { - // OL is invalid checkpoint id, worker will pass it - return 0L; - } - return checkpointIds.get(checkpointIds.size() - 1); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("jobGraph", jobGraph) - .add("executionGraph", executionGraph) - .add("conf", conf.getMap()) - .toString(); - } - -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java deleted file mode 100644 index ece4de4b7..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java +++ /dev/null @@ -1,44 +0,0 @@ -package io.ray.streaming.runtime.master.coordinator; - -import io.ray.api.Ray; -import io.ray.streaming.runtime.master.JobMaster; -import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; -import io.ray.streaming.runtime.master.graphmanager.GraphManager; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public abstract class BaseCoordinator implements Runnable { - - private static final Logger LOG = LoggerFactory.getLogger(BaseCoordinator.class); - - protected final JobMaster jobMaster; - - protected final JobMasterRuntimeContext runtimeContext; - protected final GraphManager graphManager; - protected volatile boolean closed; - private Thread thread; - - public BaseCoordinator(JobMaster jobMaster) { - this.jobMaster = jobMaster; - this.runtimeContext = jobMaster.getRuntimeContext(); - this.graphManager = jobMaster.getGraphManager(); - } - - public void start() { - thread = new Thread(Ray.wrapRunnable(this), - this.getClass().getName() + "-" + System.currentTimeMillis()); - thread.start(); - } - - public void stop() { - closed = true; - - try { - if (thread != null) { - thread.join(30000); - } - } catch (InterruptedException e) { - LOG.error("Coordinator thread exit has exception.", e); - } - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java deleted file mode 100644 index 862528776..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java +++ /dev/null @@ -1,215 +0,0 @@ -package io.ray.streaming.runtime.master.coordinator; - -import com.google.common.base.Preconditions; -import io.ray.api.BaseActorHandle; -import io.ray.api.ObjectRef; -import io.ray.api.id.ActorId; -import io.ray.runtime.exception.RayException; -import io.ray.streaming.runtime.master.JobMaster; -import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd; -import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; -import io.ray.streaming.runtime.rpc.RemoteCallWorker; -import io.ray.streaming.runtime.worker.JobWorker; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * CheckpointCoordinator is the controller of checkpoint, responsible for triggering checkpoint, - * collecting {@link JobWorker}'s reports and calling {@link JobWorker} to clear expired - * checkpoints when new checkpoint finished. - */ -public class CheckpointCoordinator extends BaseCoordinator { - - private static final Logger LOG = LoggerFactory.getLogger(CheckpointCoordinator.class); - private final Set pendingCheckpointActors = new HashSet<>(); - private final Set interruptedCheckpointSet = new HashSet<>(); - private final int cpIntervalSecs; - private final int cpTimeoutSecs; - - public CheckpointCoordinator(JobMaster jobMaster) { - super(jobMaster); - - // get checkpoint interval from conf - this.cpIntervalSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpIntervalSecs(); - this.cpTimeoutSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpTimeoutSecs(); - - // Trigger next checkpoint in interval by reset last checkpoint timestamp. - runtimeContext.lastCpTimestamp = System.currentTimeMillis(); - } - - @Override - public void run() { - while (!closed) { - try { - final BaseWorkerCmd command = runtimeContext.cpCmds.poll(1, TimeUnit.SECONDS); - if (command != null) { - if (command instanceof WorkerCommitReport) { - processCommitReport((WorkerCommitReport) command); - } else { - interruptCheckpoint(); - } - } - - if (!pendingCheckpointActors.isEmpty()) { - // if wait commit report timeout, this cp fail, and restart next cp - if (timeoutOnWaitCheckpoint()) { - LOG.warn("Waiting for checkpoint {} timeout, pending cp actors is {}.", - runtimeContext.lastCheckpointId, - graphManager.getExecutionGraph().getActorName(pendingCheckpointActors)); - - interruptCheckpoint(); - } - } else { - maybeTriggerCheckpoint(); - } - } catch (Throwable e) { - LOG.error("Checkpoint coordinator occur err.", e); - try { - interruptCheckpoint(); - } catch (Throwable interruptE) { - LOG.error("Ignore interrupt checkpoint exception in catch block."); - } - } - } - LOG.warn("Checkpoint coordinator thread exit."); - } - - public Boolean reportJobWorkerCommit(WorkerCommitReport report) { - LOG.info("Report job worker commit {}.", report); - - Boolean ret = runtimeContext.cpCmds.offer(report); - if (!ret) { - LOG.warn("Report job worker commit failed, because command queue is full."); - } - return ret; - } - - private void processCommitReport(WorkerCommitReport commitReport) { - LOG.info("Start process commit report {}, from actor name={}.", commitReport, - graphManager.getExecutionGraph().getActorName(commitReport.fromActorId)); - - try { - Preconditions.checkArgument( - commitReport.commitCheckpointId == runtimeContext.lastCheckpointId, - "expect checkpointId %s, but got %s", - runtimeContext.lastCheckpointId, commitReport); - - if (!pendingCheckpointActors.contains(commitReport.fromActorId)) { - LOG.warn("Invalid commit report, skipped."); - return; - } - - pendingCheckpointActors.remove(commitReport.fromActorId); - LOG.info("Pending actors after this commit: {}.", - graphManager.getExecutionGraph().getActorName(pendingCheckpointActors)); - - // checkpoint finish - if (pendingCheckpointActors.isEmpty()) { - // actor finish - runtimeContext.checkpointIds.add(runtimeContext.lastCheckpointId); - - if (clearExpiredCpStateAndQueueMsg()) { - // save master context - jobMaster.saveContext(); - - LOG.info("Finish checkpoint: {}.", runtimeContext.lastCheckpointId); - } else { - LOG.warn("Fail to do checkpoint: {}.", runtimeContext.lastCheckpointId); - } - } - - LOG.info("Process commit report {} success.", commitReport); - } catch (Throwable e) { - LOG.warn("Process commit report has exception.", e); - } - } - - private void triggerCheckpoint() { - interruptedCheckpointSet.clear(); - if (LOG.isInfoEnabled()) { - LOG.info("Start trigger checkpoint {}.", runtimeContext.lastCheckpointId + 1); - } - - List allIds = graphManager.getExecutionGraph().getAllActorsId(); - // do the checkpoint - pendingCheckpointActors.addAll(allIds); - - // inc last checkpoint id - ++runtimeContext.lastCheckpointId; - - final List sourcesRet = new ArrayList<>(); - - graphManager.getExecutionGraph().getSourceActors().forEach(actor -> { - sourcesRet.add(RemoteCallWorker.triggerCheckpoint( - actor, runtimeContext.lastCheckpointId)); - }); - - for (ObjectRef rayObject : sourcesRet) { - if (rayObject.get() instanceof RayException) { - LOG.warn("Trigger checkpoint has exception.", (RayException) rayObject.get()); - throw (RayException) rayObject.get(); - } - } - runtimeContext.lastCpTimestamp = System.currentTimeMillis(); - LOG.info("Trigger checkpoint success."); - } - - private void interruptCheckpoint() { - // notify checkpoint timeout is time-consuming while many workers crash or - // container failover. - if (interruptedCheckpointSet.contains(runtimeContext.lastCheckpointId)) { - LOG.warn("Skip interrupt duplicated checkpoint id : {}.", runtimeContext.lastCheckpointId); - return; - } - interruptedCheckpointSet.add(runtimeContext.lastCheckpointId); - LOG.warn("Interrupt checkpoint, checkpoint id : {}.", runtimeContext.lastCheckpointId); - - List allActor = graphManager.getExecutionGraph().getAllActors(); - if (runtimeContext.lastCheckpointId > runtimeContext.getLastValidCheckpointId()) { - RemoteCallWorker - .notifyCheckpointTimeoutParallel(allActor, runtimeContext.lastCheckpointId); - } - - if (!pendingCheckpointActors.isEmpty()) { - pendingCheckpointActors.clear(); - } - maybeTriggerCheckpoint(); - } - - private void maybeTriggerCheckpoint() { - if (readyToTrigger()) { - triggerCheckpoint(); - } - } - - private boolean clearExpiredCpStateAndQueueMsg() { - // queue msg must clear when first checkpoint finish - List allActor = graphManager.getExecutionGraph().getAllActors(); - if (1 == runtimeContext.checkpointIds.size()) { - Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0); - RemoteCallWorker.clearExpiredCheckpointParallel(allActor, 0L, msgExpiredCheckpointId); - } - - if (runtimeContext.checkpointIds.size() > 1) { - Long stateExpiredCpId = runtimeContext.checkpointIds.remove(0); - Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0); - RemoteCallWorker - .clearExpiredCheckpointParallel(allActor, stateExpiredCpId, msgExpiredCheckpointId); - } - return true; - } - - private boolean readyToTrigger() { - return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp) >= - cpIntervalSecs * 1000; - } - - private boolean timeoutOnWaitCheckpoint() { - return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp) >= cpTimeoutSecs * 1000; - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java deleted file mode 100644 index c58c84d6a..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java +++ /dev/null @@ -1,281 +0,0 @@ -package io.ray.streaming.runtime.master.coordinator; - -import io.ray.api.BaseActorHandle; -import io.ray.api.id.ActorId; -import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; -import io.ray.streaming.runtime.core.resource.Container; -import io.ray.streaming.runtime.master.JobMaster; -import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; -import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd; -import io.ray.streaming.runtime.master.coordinator.command.InterruptCheckpointRequest; -import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; -import io.ray.streaming.runtime.rpc.async.AsyncRemoteCaller; -import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; -import io.ray.streaming.runtime.util.ResourceUtil; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; -import org.apache.commons.collections.map.DefaultedMap; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class FailoverCoordinator extends BaseCoordinator { - - private static final Logger LOG = LoggerFactory.getLogger(FailoverCoordinator.class); - - private static final int ROLLBACK_RETRY_TIME_MS = 10 * 1000; - private final Object cmdLock = new Object(); - private final AsyncRemoteCaller asyncRemoteCaller; - private long currentCascadingGroupId = 0; - private final Map isRollbacking = - DefaultedMap.decorate(new ConcurrentHashMap(), false); - - public FailoverCoordinator(JobMaster jobMaster, boolean isRecover) { - this(jobMaster, new AsyncRemoteCaller(), isRecover); - } - - public FailoverCoordinator( - JobMaster jobMaster, AsyncRemoteCaller asyncRemoteCaller, - boolean isRecover) { - super(jobMaster); - - this.asyncRemoteCaller = asyncRemoteCaller; - // recover unfinished FO commands - JobMasterRuntimeContext runtimeContext = jobMaster.getRuntimeContext(); - if (isRecover) { - runtimeContext.foCmds.addAll(runtimeContext.unfinishedFoCmds); - } - runtimeContext.unfinishedFoCmds.clear(); - } - - @Override - public void run() { - while (!closed) { - try { - final BaseWorkerCmd command; - // see rollback() for lock reason - synchronized (cmdLock) { - command = jobMaster.getRuntimeContext().foCmds.poll(1, TimeUnit.SECONDS); - } - if (null == command) { - continue; - } - if (command instanceof WorkerRollbackRequest) { - jobMaster.getRuntimeContext().unfinishedFoCmds.add(command); - dealWithRollbackRequest((WorkerRollbackRequest) command); - } - } catch (Throwable e) { - LOG.error("Fo coordinator occur err.", e); - } - } - LOG.warn("Fo coordinator thread exit."); - } - - private Boolean isDuplicateRequest(WorkerRollbackRequest request) { - try { - Object[] foCmdsArray = runtimeContext.foCmds.toArray(); - for (Object cmd : foCmdsArray) { - if (request.fromActorId.equals(((BaseWorkerCmd) cmd).fromActorId)) { - return true; - } - } - } catch (Exception e) { - LOG.warn("Check request is duplicated failed.", e); - } - return false; - } - - public Boolean requestJobWorkerRollback(WorkerRollbackRequest request) { - LOG.info("Request job worker rollback {}.", request); - boolean ret; - if (!isDuplicateRequest(request)) { - ret = runtimeContext.foCmds.offer(request); - } else { - LOG.warn("Skip duplicated worker rollback request, {}.", request.toString()); - return true; - } - jobMaster.saveContext(); - if (!ret) { - LOG.warn("Request job worker rollback failed, because command queue is full."); - } - return ret; - } - - private void dealWithRollbackRequest(WorkerRollbackRequest rollbackRequest) { - LOG.info("Start deal with rollback request {}.", rollbackRequest); - - ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest); - - // Reset pid for new-rollback actor. - if (null != rollbackRequest.getPid() && - !rollbackRequest.getPid().equals(WorkerRollbackRequest.DEFAULT_PID)) { - exeVertex.setPid(rollbackRequest.getPid()); - } - - if (isRollbacking.get(exeVertex)) { - LOG.info("Vertex {} is rollbacking, skip rollback again.", exeVertex); - return; - } - - String hostname = ""; - Optional container = ResourceUtil.getContainerById( - jobMaster.getResourceManager().getRegisteredContainers(), - exeVertex.getContainerId() - ); - if (container.isPresent()) { - hostname = container.get().getHostname(); - } - - if (rollbackRequest.isForcedRollback) { - interruptCheckpointAndRollback(rollbackRequest); - } else { - asyncRemoteCaller.checkIfNeedRollbackAsync(exeVertex.getWorkerActor(), res -> { - if (!res) { - LOG.info("Vertex {} doesn't need to rollback, skip it.", exeVertex); - return; - } - interruptCheckpointAndRollback(rollbackRequest); - }, throwable -> { - LOG.error("Exception when calling checkIfNeedRollbackAsync, maybe vertex is dead" + - ", ignore this request, vertex={}.", exeVertex, throwable); - }); - } - - LOG.info("Deal with rollback request {} success.", rollbackRequest); - } - - private void interruptCheckpointAndRollback(WorkerRollbackRequest rollbackRequest) { - // assign a cascadingGroupId - if (rollbackRequest.cascadingGroupId == null) { - rollbackRequest.cascadingGroupId = currentCascadingGroupId++; - } - // get last valid checkpoint id then call worker rollback - rollback(jobMaster.getRuntimeContext().getLastValidCheckpointId(), rollbackRequest, - currentCascadingGroupId); - // we interrupt current checkpoint for 2 considerations: - // 1. current checkpoint might be timeout, because barrier might be lost after failover. so we - // interrupt current checkpoint to avoid waiting. - // 2. when we want to rollback vertex to n, job finished checkpoint n+1 and cleared state - // of checkpoint n. - jobMaster.getRuntimeContext().cpCmds.offer(new InterruptCheckpointRequest()); - } - - /** - * call worker rollback, and deal with it's reports. callback won't be finished until - * the entire DAG back to normal. - * - * @param checkpointId checkpointId to be rollback - * @param rollbackRequest worker rollback request - * @param cascadingGroupId all rollback of a cascading group should have same ID - */ - private void rollback( - long checkpointId, WorkerRollbackRequest rollbackRequest, - long cascadingGroupId) { - ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest); - LOG.info("Call vertex {} to rollback, checkpoint id is {}, cascadingGroupId={}.", - exeVertex, checkpointId, cascadingGroupId); - - isRollbacking.put(exeVertex, true); - - asyncRemoteCaller.rollback(exeVertex.getWorkerActor(), checkpointId, result -> { - List newRollbackRequests = new ArrayList<>(); - switch (result.getResultEnum()) { - case SUCCESS: - ChannelRecoverInfo recoverInfo = result.getResultObj(); - LOG.info("Vertex {} rollback done, dataLostQueues={}, msg={}, cascadingGroupId={}.", - exeVertex, recoverInfo.getDataLostQueues(), result.getResultMsg(), cascadingGroupId); - // rollback upstream if vertex reports abnormal input queues - newRollbackRequests = - cascadeUpstreamActors(recoverInfo.getDataLostQueues(), exeVertex, cascadingGroupId); - break; - case SKIPPED: - LOG.info("Vertex skip rollback, result = {}, cascadingGroupId={}.", result, - cascadingGroupId); - break; - default: - LOG.error( - "Rollback vertex {} failed, result={}, cascadingGroupId={}," + - " rollback this worker again after {} ms.", - exeVertex, result, cascadingGroupId, ROLLBACK_RETRY_TIME_MS); - Thread.sleep(ROLLBACK_RETRY_TIME_MS); - LOG.info("Add rollback request for {} again, cascadingGroupId={}.", exeVertex, - cascadingGroupId); - newRollbackRequests.add( - new WorkerRollbackRequest(exeVertex, "", "Rollback failed, try again.", false) - ); - break; - } - - // lock to avoid executing new rollback requests added. - // consider such a case: A->B->C, C cascade B, and B cascade A - // if B is rollback before B's rollback request is saved, and then JobMaster crashed, - // then A will never be rollback. - synchronized (cmdLock) { - jobMaster.getRuntimeContext().foCmds.addAll(newRollbackRequests); - // this rollback request is finished, remove it. - jobMaster.getRuntimeContext().unfinishedFoCmds.remove(rollbackRequest); - jobMaster.saveContext(); - } - isRollbacking.put(exeVertex, false); - }, throwable -> { - LOG.error("Exception when calling vertex to rollback, vertex={}.", exeVertex, throwable); - isRollbacking.put(exeVertex, false); - }); - - LOG.info("Finish rollback vertex {}, checkpoint id is {}.", exeVertex, checkpointId); - } - - private List cascadeUpstreamActors( - Set dataLostQueues, ExecutionVertex fromVertex, long cascadingGroupId) { - List cascadedRollbackRequest = new ArrayList<>(); - // rollback upstream if vertex reports abnormal input queues - dataLostQueues.forEach(q -> { - BaseActorHandle upstreamActor = - graphManager.getExecutionGraph().getPeerActor(fromVertex.getWorkerActor(), q); - ExecutionVertex upstreamExeVertex = getExecutionVertex(upstreamActor); - // vertexes that has already cascaded by other vertex in the same level - // of graph should be ignored. - if (isRollbacking.get(upstreamExeVertex)) { - return; - } - LOG.info("Call upstream vertex {} of vertex {} to rollback, cascadingGroupId={}.", - upstreamExeVertex, fromVertex, cascadingGroupId); - String hostname = ""; - Optional container = ResourceUtil.getContainerById( - jobMaster.getResourceManager().getRegisteredContainers(), - upstreamExeVertex.getContainerId() - ); - if (container.isPresent()) { - hostname = container.get().getHostname(); - } - // force upstream vertexes to rollback - WorkerRollbackRequest upstreamRequest = new WorkerRollbackRequest( - upstreamExeVertex, hostname, String.format("Cascading rollback from %s", fromVertex), true - ); - upstreamRequest.cascadingGroupId = cascadingGroupId; - cascadedRollbackRequest.add(upstreamRequest); - }); - return cascadedRollbackRequest; - } - - private ExecutionVertex getExeVertexFromRequest(WorkerRollbackRequest rollbackRequest) { - ActorId actorId = rollbackRequest.fromActorId; - Optional rayActor = graphManager.getExecutionGraph().getActorById(actorId); - if (!rayActor.isPresent()) { - throw new RuntimeException("Can not find ray actor of ID " + actorId); - } - return getExecutionVertex(rollbackRequest.fromActorId); - } - - private ExecutionVertex getExecutionVertex(BaseActorHandle actor) { - return graphManager.getExecutionGraph().getExecutionVertexByActorId(actor.getId()); - } - - private ExecutionVertex getExecutionVertex(ActorId actorId) { - return graphManager.getExecutionGraph().getExecutionVertexByActorId(actorId); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java deleted file mode 100644 index 2c6a9322d..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java +++ /dev/null @@ -1,17 +0,0 @@ -package io.ray.streaming.runtime.master.coordinator.command; - -import io.ray.api.id.ActorId; -import java.io.Serializable; - -public abstract class BaseWorkerCmd implements Serializable { - - public ActorId fromActorId; - - public BaseWorkerCmd() { - } - - protected BaseWorkerCmd(ActorId actorId) { - this.fromActorId = actorId; - } - -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java deleted file mode 100644 index 29a46ab10..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java +++ /dev/null @@ -1,5 +0,0 @@ -package io.ray.streaming.runtime.master.coordinator.command; - -public final class InterruptCheckpointRequest extends BaseWorkerCmd { - -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java deleted file mode 100644 index 7750ce1b0..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java +++ /dev/null @@ -1,22 +0,0 @@ -package io.ray.streaming.runtime.master.coordinator.command; - -import com.google.common.base.MoreObjects; -import io.ray.api.id.ActorId; - -public final class WorkerCommitReport extends BaseWorkerCmd { - - public final long commitCheckpointId; - - public WorkerCommitReport(ActorId actorId, long commitCheckpointId) { - super(actorId); - this.commitCheckpointId = commitCheckpointId; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("commitCheckpointId", commitCheckpointId) - .add("fromActorId", fromActorId) - .toString(); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java deleted file mode 100644 index e56518382..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java +++ /dev/null @@ -1,63 +0,0 @@ -package io.ray.streaming.runtime.master.coordinator.command; - -import com.google.common.base.MoreObjects; -import io.ray.api.id.ActorId; -import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; - -public final class WorkerRollbackRequest extends BaseWorkerCmd { - - public static String DEFAULT_PID = "UNKNOWN_PID"; - public Long cascadingGroupId = null; - public boolean isForcedRollback = false; - private String exceptionMsg = "No detail message."; - private String hostname = "UNKNOWN_HOST"; - private String pid = DEFAULT_PID; - - public WorkerRollbackRequest(ActorId actorId) { - super(actorId); - } - - public WorkerRollbackRequest(ActorId actorId, String msg) { - super(actorId); - exceptionMsg = msg; - } - - public WorkerRollbackRequest( - ExecutionVertex executionVertex, - String hostname, - String msg, - boolean isForcedRollback) { - - super(executionVertex.getWorkerActorId()); - - this.hostname = hostname; - this.pid = executionVertex.getPid(); - this.exceptionMsg = msg; - this.isForcedRollback = isForcedRollback; - } - - public WorkerRollbackRequest(ActorId actorId, String msg, String hostname, String pid) { - this(actorId, msg); - this.hostname = hostname; - this.pid = pid; - } - - public String getRollbackExceptionMsg() { - return exceptionMsg; - } - - public String getHostname() { - return hostname; - } - - public String getPid() { - return pid; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("fromActorId", fromActorId) - .toString(); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java index a977967ff..e76963a47 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java @@ -1,19 +1,14 @@ package io.ray.streaming.runtime.master.graphmanager; -import io.ray.api.BaseActorHandle; import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobVertex; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobEdge; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; -import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; -import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; -import java.util.HashMap; -import java.util.HashSet; +import io.ray.streaming.runtime.master.JobRuntimeContext; import java.util.LinkedHashMap; import java.util.Map; -import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -21,9 +16,9 @@ public class GraphManagerImpl implements GraphManager { private static final Logger LOG = LoggerFactory.getLogger(GraphManagerImpl.class); - protected final JobMasterRuntimeContext runtimeContext; + protected final JobRuntimeContext runtimeContext; - public GraphManagerImpl(JobMasterRuntimeContext runtimeContext) { + public GraphManagerImpl(JobRuntimeContext runtimeContext) { this.runtimeContext = runtimeContext; } @@ -53,7 +48,6 @@ public class GraphManagerImpl implements GraphManager { // create vertex Map exeJobVertexMap = new LinkedHashMap<>(); - Map executionVertexMap = new HashMap<>(); long buildTime = executionGraph.getBuildTime(); for (JobVertex jobVertex : jobGraph.getJobVertices()) { int jobVertexId = jobVertex.getVertexId(); @@ -65,47 +59,32 @@ public class GraphManagerImpl implements GraphManager { buildTime)); } - // for each job edge, connect all source exeVertices and target exeVertices + // connect vertex jobGraph.getJobEdges().forEach(jobEdge -> { ExecutionJobVertex source = exeJobVertexMap.get(jobEdge.getSrcVertexId()); ExecutionJobVertex target = exeJobVertexMap.get(jobEdge.getTargetVertexId()); - ExecutionJobEdge executionJobEdge = new ExecutionJobEdge(source, target, jobEdge); + ExecutionJobEdge executionJobEdge = + new ExecutionJobEdge(source, target, jobEdge); source.getOutputEdges().add(executionJobEdge); target.getInputEdges().add(executionJobEdge); - source.getExecutionVertices().forEach(sourceExeVertex -> { - target.getExecutionVertices().forEach(targetExeVertex -> { - // pre-process some mappings - executionVertexMap.put(targetExeVertex.getExecutionVertexId(), targetExeVertex); - executionVertexMap.put(sourceExeVertex.getExecutionVertexId(), sourceExeVertex); - // build execution edge - ExecutionEdge executionEdge = - new ExecutionEdge(sourceExeVertex, targetExeVertex, executionJobEdge); - sourceExeVertex.getOutputEdges().add(executionEdge); - targetExeVertex.getInputEdges().add(executionEdge); + source.getExecutionVertices().forEach(vertex -> { + target.getExecutionVertices().forEach(outputVertex -> { + ExecutionEdge executionEdge = new ExecutionEdge(vertex, outputVertex, executionJobEdge); + vertex.getOutputEdges().add(executionEdge); + outputVertex.getInputEdges().add(executionEdge); }); }); }); // set execution job vertex into execution graph executionGraph.setExecutionJobVertexMap(exeJobVertexMap); - executionGraph.setExecutionVertexMap(executionVertexMap); return executionGraph; } - private void addActorToChannelGroupedActors( - Map> channelGroupedActors, - String channelId, - BaseActorHandle actor) { - - Set actorSet = - channelGroupedActors.computeIfAbsent(channelId, k -> new HashSet<>()); - actorSet.add(actor); - } - @Override public JobGraph getJobGraph() { return runtimeContext.getJobGraph(); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java index 2e59fed09..3b7b35ba6 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java @@ -11,7 +11,7 @@ import io.ray.streaming.runtime.config.types.ResourceAssignStrategyType; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.resource.Container; import io.ray.streaming.runtime.core.resource.Resources; -import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.JobRuntimeContext; import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategy; import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategyFactory; import io.ray.streaming.runtime.util.RayUtils; @@ -30,33 +30,39 @@ public class ResourceManagerImpl implements ResourceManager { //Container used tag private static final String CONTAINER_ENGAGED_KEY = "CONTAINER_ENGAGED_KEY"; - /** - * Resource description information. - */ - private final Resources resources; - /** - * Timing resource updating thread - */ - private final ScheduledExecutorService resourceUpdater = new ScheduledThreadPoolExecutor(1, - new ThreadFactoryBuilder().setNameFormat("resource-update-thread").build()); + /** * Job runtime context. */ - private JobMasterRuntimeContext runtimeContext; + private JobRuntimeContext runtimeContext; + /** * Resource related configuration. */ private ResourceConfig resourceConfig; + /** * Slot assign strategy. */ private ResourceAssignStrategy resourceAssignStrategy; + + /** + * Resource description information. + */ + private final Resources resources; + /** * Customized actor number for each container */ private int actorNumPerContainer; - public ResourceManagerImpl(JobMasterRuntimeContext runtimeContext) { + /** + * Timing resource updating thread + */ + private final ScheduledExecutorService resourceUpdater = new ScheduledThreadPoolExecutor(1, + new ThreadFactoryBuilder().setNameFormat("resource-update-thread").build()); + + public ResourceManagerImpl(JobRuntimeContext runtimeContext) { this.runtimeContext = runtimeContext; StreamingMasterConfig masterConfig = runtimeContext.getConf().masterConfig; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java index 238fdf6f7..6b9b3a690 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java @@ -23,18 +23,20 @@ import org.slf4j.LoggerFactory; public class JobSchedulerImpl implements JobScheduler { private static final Logger LOG = LoggerFactory.getLogger(JobSchedulerImpl.class); + + private StreamingConfig jobConf; + private final JobMaster jobMaster; private final ResourceManager resourceManager; private final GraphManager graphManager; private final WorkerLifecycleController workerLifecycleController; - private StreamingConfig jobConfig; public JobSchedulerImpl(JobMaster jobMaster) { this.jobMaster = jobMaster; this.graphManager = jobMaster.getGraphManager(); this.resourceManager = jobMaster.getResourceManager(); this.workerLifecycleController = new WorkerLifecycleController(); - this.jobConfig = jobMaster.getRuntimeContext().getConf(); + this.jobConf = jobMaster.getRuntimeContext().getConf(); LOG.info("Scheduler initiated."); } @@ -44,13 +46,8 @@ public class JobSchedulerImpl implements JobScheduler { LOG.info("Begin scheduling. Job: {}.", executionGraph.getJobName()); // Allocate resource then create workers - // Actor creation is in this step prepareResourceAndCreateWorker(executionGraph); - // now actor info is available in execution graph - // preprocess some handy mappings in execution graph - executionGraph.generateActorMappings(); - // init worker context and start to run initAndStart(executionGraph); @@ -90,7 +87,7 @@ public class JobSchedulerImpl implements JobScheduler { initMaster(); // start workers - startWorkers(executionGraph, jobMaster.getRuntimeContext().lastCheckpointId); + startWorkers(executionGraph); } /** @@ -125,7 +122,7 @@ public class JobSchedulerImpl implements JobScheduler { boolean result; try { result = workerLifecycleController.initWorkers(vertexToContextMap, - jobConfig.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs()); + jobConf.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs()); } catch (Exception e) { LOG.error("Failed to initiate workers.", e); return false; @@ -136,12 +133,11 @@ public class JobSchedulerImpl implements JobScheduler { /** * Start JobWorkers according to the physical plan. */ - public boolean startWorkers(ExecutionGraph executionGraph, long checkpointId) { + public boolean startWorkers(ExecutionGraph executionGraph) { boolean result; try { result = workerLifecycleController.startWorkers( - executionGraph, checkpointId, - jobConfig.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs()); + executionGraph, jobConf.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs()); } catch (Exception e) { LOG.error("Failed to start workers.", e); return false; @@ -198,7 +194,7 @@ public class JobSchedulerImpl implements JobScheduler { } private void initMaster() { - jobMaster.init(false); + jobMaster.init(); } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java index bc8b462c7..876e9f924 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java @@ -9,8 +9,6 @@ import io.ray.api.id.ActorId; import io.ray.streaming.api.Language; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; -import io.ray.streaming.runtime.generated.RemoteCall; -import io.ray.streaming.runtime.python.GraphPbBuilder; import io.ray.streaming.runtime.rpc.RemoteCallWorker; import io.ray.streaming.runtime.worker.JobWorker; import io.ray.streaming.runtime.worker.context.JobWorkerContext; @@ -42,23 +40,20 @@ public class WorkerLifecycleController { * @return creation result */ private boolean createWorker(ExecutionVertex executionVertex) { - LOG.info("Start to create worker actor for vertex: {} with resource: {}, workeConfig: {}.", - executionVertex.getExecutionVertexName(), executionVertex.getResource(), - executionVertex.getWorkerConfig()); + LOG.info("Start to create worker actor for vertex: {} with resource: {}.", + executionVertex.getExecutionVertexName(), executionVertex.getResource()); Language language = executionVertex.getLanguage(); BaseActorHandle actor; if (Language.JAVA == language) { - actor = Ray.actor(JobWorker::new, executionVertex) + actor = Ray.actor(JobWorker::new) .setResources(executionVertex.getResource()) .setMaxRestarts(-1) .remote(); } else { - RemoteCall.ExecutionVertexContext.ExecutionVertex vertexPb - = new GraphPbBuilder().buildVertex(executionVertex); actor = Ray.actor( - PyActorClass.of("ray.streaming.runtime.worker", "JobWorker"), vertexPb.toByteArray()) + PyActorClass.of("ray.streaming.runtime.worker", "JobWorker")) .setResources(executionVertex.getResource()) .setMaxRestarts(-1) .remote(); @@ -116,20 +111,20 @@ public class WorkerLifecycleController { * @param timeout timeout for waiting, unit: ms * @return starting result */ - public boolean startWorkers(ExecutionGraph executionGraph, long lastCheckpointId, int timeout) { + public boolean startWorkers(ExecutionGraph executionGraph, int timeout) { LOG.info("Begin starting workers."); long startTime = System.currentTimeMillis(); - List> objectRefs = new ArrayList<>(); + List> objectRefs = new ArrayList<>(); // start source actors 1st executionGraph.getSourceActors() - .forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId))); + .forEach(actor -> objectRefs.add(RemoteCallWorker.startWorker(actor))); // then start non-source actors executionGraph.getNonSourceActors() - .forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId))); + .forEach(actor -> objectRefs.add(RemoteCallWorker.startWorker(actor))); - WaitResult result = Ray.wait(objectRefs, objectRefs.size(), timeout); + WaitResult result = Ray.wait(objectRefs, objectRefs.size(), timeout); if (result.getReady().size() != objectRefs.size()) { LOG.error("Starting workers timeout[{} ms].", timeout); return false; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java deleted file mode 100644 index 5cdba0b0a..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java +++ /dev/null @@ -1,122 +0,0 @@ -package io.ray.streaming.runtime.message; - -import com.google.common.base.MoreObjects; -import java.io.Serializable; - -public class CallResult implements Serializable { - - protected T resultObj; - private boolean success; - private int resultCode; - private String resultMsg; - - public CallResult() { - } - - public CallResult(boolean success, int resultCode, String resultMsg, T resultObj) { - this.success = success; - this.resultCode = resultCode; - this.resultMsg = resultMsg; - this.resultObj = resultObj; - } - - public static CallResult success() { - return new CallResult<>(true, CallResultEnum.SUCCESS.code, CallResultEnum.SUCCESS.msg, null); - } - - public static CallResult success(T payload) { - return new CallResult<>(true, CallResultEnum.SUCCESS.code, CallResultEnum.SUCCESS.msg, payload); - } - - public static CallResult skipped(String msg) { - return new CallResult<>(true, CallResultEnum.SKIPPED.code, msg, null); - } - - public static CallResult fail() { - return new CallResult<>(false, CallResultEnum.FAILED.code, CallResultEnum.FAILED.msg, null); - } - - public static CallResult fail(T payload) { - return new CallResult<>(false, CallResultEnum.FAILED.code, CallResultEnum.FAILED.msg, payload); - } - - public static CallResult fail(String msg) { - return new CallResult<>(false, CallResultEnum.FAILED.code, msg, null); - } - - public static CallResult fail(CallResultEnum resultEnum, T payload) { - return new CallResult<>(false, resultEnum.code, resultEnum.msg, payload); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("resultObj", resultObj) - .add("success", success) - .add("resultCode", resultCode) - .add("resultMsg", resultMsg) - .toString(); - } - - public boolean isSuccess() { - return this.success; - } - - public void setSuccess(boolean success) { - this.success = success; - } - - public int getResultCode() { - return this.resultCode; - } - - public void setResultCode(int resultCode) { - this.resultCode = resultCode; - } - - public CallResultEnum getResultEnum() { - return CallResultEnum.getEnum(this.resultCode); - } - - public String getResultMsg() { - return this.resultMsg; - } - - public void setResultMsg(String resultMsg) { - this.resultMsg = resultMsg; - } - - public T getResultObj() { - return this.resultObj; - } - - public void setResultObj(T resultObj) { - this.resultObj = resultObj; - } - - public enum CallResultEnum implements Serializable { - /** - * call result enum - */ - SUCCESS(0, "SUCCESS"), - FAILED(1, "FAILED"), - SKIPPED(2, "SKIPPED"); - - public final int code; - public final String msg; - - CallResultEnum(int code, String msg) { - this.code = code; - this.msg = msg; - } - - public static CallResultEnum getEnum(int code) { - for (CallResultEnum value : CallResultEnum.values()) { - if (code == value.code) { - return value; - } - } - return FAILED; - } - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java index f7bbc6278..408397ebb 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java @@ -65,7 +65,7 @@ public class GraphPbBuilder { return builder.build(); } - public RemoteCall.ExecutionVertexContext.ExecutionVertex buildVertex( + private RemoteCall.ExecutionVertexContext.ExecutionVertex buildVertex( ExecutionVertex executionVertex) { // build vertex infos RemoteCall.ExecutionVertexContext.ExecutionVertex.Builder executionVertexBuilder = @@ -79,11 +79,9 @@ public class GraphPbBuilder { ByteString.copyFrom( serializeOperator(executionVertex.getStreamOperator()))); executionVertexBuilder.setChained(isPythonChainedOperator(executionVertex.getStreamOperator())); - if (executionVertex.getWorkerActor() != null) { - executionVertexBuilder.setWorkerActor( - ByteString.copyFrom( - ((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes())); - } + executionVertexBuilder.setWorkerActor( + ByteString.copyFrom( + ((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes())); executionVertexBuilder.setContainerId(executionVertex.getContainerId().toString()); executionVertexBuilder.setBuildTime(executionVertex.getBuildTime()); executionVertexBuilder.setLanguage( diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java deleted file mode 100644 index c0bbcc2c2..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java +++ /dev/null @@ -1,54 +0,0 @@ -package io.ray.streaming.runtime.rpc; - -import com.google.protobuf.InvalidProtocolBufferException; -import io.ray.streaming.runtime.generated.RemoteCall; -import io.ray.streaming.runtime.message.CallResult; -import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; -import java.util.HashMap; -import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class PbResultParser { - - private static final Logger LOG = LoggerFactory.getLogger(PbResultParser.class); - - public static Boolean parseBoolResult(byte[] result) { - if (null == result) { - LOG.warn("Result is null."); - return false; - } - - RemoteCall.BoolResult boolResult; - try { - boolResult = RemoteCall.BoolResult.parseFrom(result); - } catch (InvalidProtocolBufferException e) { - LOG.error("Parse boolean result has exception.", e); - return false; - } - - return boolResult.getBoolRes(); - } - - public static CallResult parseRollbackResult(byte[] bytes) { - RemoteCall.CallResult callResultPb; - try { - callResultPb = RemoteCall.CallResult.parseFrom(bytes); - } catch (InvalidProtocolBufferException e) { - LOG.error("Rollback parse result has exception.", e); - return CallResult.fail(); - } - - CallResult callResult = new CallResult<>(); - callResult.setSuccess(callResultPb.getSuccess()); - callResult.setResultCode(callResultPb.getResultCode()); - callResult.setResultMsg(callResultPb.getResultMsg()); - RemoteCall.QueueRecoverInfo recoverInfo = callResultPb.getResultObj(); - Map creationStatusMap = new HashMap<>(); - recoverInfo.getCreationStatusMap().forEach((k, v) -> { - creationStatusMap.put(k, ChannelRecoverInfo.ChannelCreationStatus.fromInt(v.getNumber())); - }); - callResult.setResultObj(new ChannelRecoverInfo(creationStatusMap)); - return callResult; - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java deleted file mode 100644 index fe25002bf..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java +++ /dev/null @@ -1,46 +0,0 @@ -package io.ray.streaming.runtime.rpc; - -import com.google.protobuf.Any; -import com.google.protobuf.ByteString; -import io.ray.api.ActorHandle; -import io.ray.api.ObjectRef; -import io.ray.streaming.runtime.generated.RemoteCall; -import io.ray.streaming.runtime.master.JobMaster; -import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; -import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; - -public class RemoteCallMaster { - - public static ObjectRef reportJobWorkerCommitAsync( - ActorHandle actor, - WorkerCommitReport commitReport) { - RemoteCall.WorkerCommitReport commit = RemoteCall.WorkerCommitReport.newBuilder() - .setCommitCheckpointId(commitReport.commitCheckpointId) - .build(); - Any detail = Any.pack(commit); - RemoteCall.BaseWorkerCmd cmd = RemoteCall.BaseWorkerCmd.newBuilder() - .setActorId(ByteString.copyFrom(commitReport.fromActorId.getBytes())) - .setTimestamp(System.currentTimeMillis()) - .setDetail(detail).build(); - - return actor.task(JobMaster::reportJobWorkerCommit, cmd.toByteArray()).remote(); - } - - public static Boolean requestJobWorkerRollback( - ActorHandle actor, - WorkerRollbackRequest rollbackRequest) { - RemoteCall.WorkerRollbackRequest request = RemoteCall.WorkerRollbackRequest.newBuilder() - .setExceptionMsg(rollbackRequest.getRollbackExceptionMsg()) - .setWorkerHostname(rollbackRequest.getHostname()) - .setWorkerPid(rollbackRequest.getPid()).build(); - Any detail = Any.pack(request); - RemoteCall.BaseWorkerCmd cmd = RemoteCall.BaseWorkerCmd.newBuilder() - .setActorId(ByteString.copyFrom(rollbackRequest.fromActorId.getBytes())) - .setTimestamp(System.currentTimeMillis()) - .setDetail(detail).build(); - ObjectRef ret = actor.task( - JobMaster::requestJobWorkerRollback, cmd.toByteArray()).remote(); - byte[] res = ret.get(); - return PbResultParser.parseBoolResult(res); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java index d9b373370..a12dfaea4 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java @@ -4,15 +4,10 @@ import io.ray.api.ActorHandle; import io.ray.api.BaseActorHandle; import io.ray.api.ObjectRef; import io.ray.api.PyActorHandle; -import io.ray.api.Ray; import io.ray.api.function.PyActorMethod; -import io.ray.api.function.RayFunc3; -import io.ray.streaming.runtime.generated.RemoteCall; import io.ray.streaming.runtime.master.JobMaster; import io.ray.streaming.runtime.worker.JobWorker; import io.ray.streaming.runtime.worker.context.JobWorkerContext; -import java.util.ArrayList; -import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,26 +46,19 @@ public class RemoteCallWorker { * Call JobWorker actor to start. * * @param actor target JobWorker actor - * @param checkpointId checkpoint ID to be rollback * @return start result */ - public static ObjectRef rollback(BaseActorHandle actor, final Long checkpointId) { + public static ObjectRef startWorker(BaseActorHandle actor) { LOG.info("Call worker to start, actor: {}.", actor.getId()); - ObjectRef result; + ObjectRef result = null; // python if (actor instanceof PyActorHandle) { - RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder() - .setCheckpointId(checkpointId) - .build(); result = ((PyActorHandle) actor) - .task(PyActorMethod.of("rollback"), - checkpointIdPb.toByteArray() - ).remote(); + .task(PyActorMethod.of("start", Boolean.class)).remote(); } else { // java - result = ((ActorHandle) actor) - .task(JobWorker::rollback, checkpointId, System.currentTimeMillis()).remote(); + result = ((ActorHandle) actor).task(JobWorker::start).remote(); } LOG.info("Finished calling worker to start."); @@ -94,92 +82,4 @@ public class RemoteCallWorker { return result; } - public static ObjectRef triggerCheckpoint(BaseActorHandle actor, Long barrierId) { - // python - if (actor instanceof PyActorHandle) { - RemoteCall.Barrier barrierPb = RemoteCall.Barrier.newBuilder().setId(barrierId).build(); - return ((PyActorHandle) actor).task( - PyActorMethod.of("commit"), barrierPb.toByteArray()).remote(); - } else { - // java - return ((ActorHandle) actor).task(JobWorker::triggerCheckpoint, barrierId) - .remote(); - } - } - - public static void clearExpiredCheckpointParallel( - List actors, Long stateCheckpointId, - Long queueCheckpointId) { - if (LOG.isInfoEnabled()) { - LOG.info("Call worker clearExpiredCheckpoint, state checkpoint id is {}," + - " queue checkpoint id is {}.", stateCheckpointId, queueCheckpointId); - } - - List result = - checkpointCompleteCommonCallTwoWay(actors, stateCheckpointId, queueCheckpointId, - "clear_expired_cp", JobWorker::clearExpiredCheckpoint); - - if (LOG.isInfoEnabled()) { - result.forEach( - obj -> LOG.info("Finish call worker clearExpiredCheckpointParallel, ret is {}.", obj)); - } - } - - public static void notifyCheckpointTimeoutParallel( - List actors, - Long checkpointId) { - LOG.info("Call worker notifyCheckpointTimeoutParallel, checkpoint id is {}", checkpointId); - - actors.forEach(actor -> { - if (actor instanceof PyActorHandle) { - RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder() - .setCheckpointId(checkpointId) - .build(); - ((PyActorHandle) actor).task(PyActorMethod.of("notify_checkpoint_timeout"), - checkpointIdPb.toByteArray()).remote(); - } else { - ((ActorHandle) actor).task(JobWorker::notifyCheckpointTimeout, checkpointId) - .remote(); - } - }); - - LOG.info("Finish call worker notifyCheckpointTimeoutParallel."); - } - - private static List checkpointCompleteCommonCallTwoWay( - List actors, Long stateCheckpointId, Long queueCheckpointId, - String pyFuncName, RayFunc3 rayFunc) { - List> waitFor = - checkpointCompleteCommonCall(actors, stateCheckpointId, queueCheckpointId, - pyFuncName, rayFunc); - return Ray.get(waitFor); - } - - private static List> checkpointCompleteCommonCall( - List actors, - Long stateCheckpointId, Long queueCheckpointId, - String pyFuncName, - RayFunc3 rayFunc) { - List> waitFor = new ArrayList<>(); - actors.forEach(actor -> { - // python - if (actor instanceof PyActorHandle) { - RemoteCall.CheckpointId stateCheckpointIdPb = RemoteCall.CheckpointId.newBuilder() - .setCheckpointId(stateCheckpointId) - .build(); - - RemoteCall.CheckpointId queueCheckpointIdPb = RemoteCall.CheckpointId.newBuilder() - .setCheckpointId(queueCheckpointId) - .build(); - waitFor.add(((PyActorHandle) actor).task(PyActorMethod.of(pyFuncName), - stateCheckpointIdPb.toByteArray(), queueCheckpointIdPb.toByteArray()).remote()); - } else { - // java - waitFor.add(((ActorHandle) actor).task(rayFunc, stateCheckpointId, queueCheckpointId) - .remote()); - } - }); - return waitFor; - } - } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java deleted file mode 100644 index db7937159..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java +++ /dev/null @@ -1,131 +0,0 @@ -package io.ray.streaming.runtime.rpc.async; - -import io.ray.api.ActorHandle; -import io.ray.api.BaseActorHandle; -import io.ray.api.ObjectRef; -import io.ray.api.PyActorHandle; -import io.ray.api.function.PyActorMethod; -import io.ray.streaming.runtime.generated.RemoteCall; -import io.ray.streaming.runtime.message.CallResult; -import io.ray.streaming.runtime.rpc.PbResultParser; -import io.ray.streaming.runtime.rpc.async.RemoteCallPool.Callback; -import io.ray.streaming.runtime.rpc.async.RemoteCallPool.ExceptionHandler; -import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; -import io.ray.streaming.runtime.worker.JobWorker; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -@SuppressWarnings("unchecked") -public class AsyncRemoteCaller { - - private static final Logger LOG = LoggerFactory.getLogger(AsyncRemoteCaller.class); - private RemoteCallPool remoteCallPool = new RemoteCallPool(); - - /** - * Call JobWorker::checkIfNeedRollback async - * - * @param actor JobWorker actor - * @param callback callback function on success - * @param onException callback function on exception - */ - public void checkIfNeedRollbackAsync( - BaseActorHandle actor, Callback callback, - ExceptionHandler onException) { - if (actor instanceof PyActorHandle) { - // python - remoteCallPool.bindCallback( - ((PyActorHandle) actor).task(PyActorMethod.of("check_if_need_rollback")).remote(), - (obj) -> { - byte[] res = (byte[]) obj; - callback.handle(PbResultParser.parseBoolResult(res)); - }, onException); - } else { - // java - remoteCallPool.bindCallback( - ((ActorHandle) actor).task(JobWorker::checkIfNeedRollback, - System.currentTimeMillis()).remote(), callback, onException); - } - } - - /** - * Call JobWorker::rollback async - * - * @param actor JobWorker actor - * @param callback callback function on success - * @param onException callback function on exception - */ - public void rollback( - BaseActorHandle actor, - final Long checkpointId, - Callback> callback, - ExceptionHandler onException) { - // python - if (actor instanceof PyActorHandle) { - RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder() - .setCheckpointId(checkpointId) - .build(); - ObjectRef call = ((PyActorHandle) actor).task(PyActorMethod.of("rollback"), - checkpointIdPb.toByteArray()).remote(); - remoteCallPool.bindCallback(call, obj -> - callback.handle(PbResultParser.parseRollbackResult((byte[]) obj)), onException); - } else { - // java - ObjectRef call = ((ActorHandle) actor).task( - JobWorker::rollback, checkpointId, System.currentTimeMillis()).remote(); - remoteCallPool.bindCallback(call, obj -> { - CallResult res = (CallResult) obj; - callback.handle(res); - }, onException); - } - } - - /** - * Call JobWorker::rollback async in batch - * - * @param actors JobWorker actor list - * @param callback callback function on success - * @param onException callback function on exception - */ - public void batchRollback( - List actors, final Long checkpointId, - Collection abnormalQueues, - Callback>> callback, - ExceptionHandler onException) { - List> rayCallList = new ArrayList<>(); - Map isPyActor = new HashMap<>(); - for (int i = 0; i < actors.size(); ++i) { - BaseActorHandle actor = actors.get(i); - ObjectRef call; - if (actor instanceof PyActorHandle) { - isPyActor.put(i, true); - RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder() - .setCheckpointId(checkpointId) - .build(); - call = ((PyActorHandle) actor).task(PyActorMethod.of("rollback"), - checkpointIdPb.toByteArray()).remote(); - } else { - // java - call = ((ActorHandle) actor).task(JobWorker::rollback, checkpointId, - System.currentTimeMillis()).remote(); - } - rayCallList.add(call); - } - remoteCallPool.bindCallback(rayCallList, objList -> { - List> results = new ArrayList<>(); - for (int i = 0; i < objList.size(); ++i) { - Object obj = objList.get(i); - if (isPyActor.getOrDefault(i, false)) { - results.add(PbResultParser.parseRollbackResult((byte[]) obj)); - } else { - results.add((CallResult) obj); - } - } - callback.handle(results); - }, onException); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java deleted file mode 100644 index 52e9e5651..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java +++ /dev/null @@ -1,189 +0,0 @@ -package io.ray.streaming.runtime.rpc.async; - -import io.ray.api.ObjectRef; -import io.ray.api.Ray; -import io.ray.api.WaitResult; -import java.util.Collections; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.ThreadFactory; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -public class RemoteCallPool implements Runnable { - - private static final Logger LOG = LoggerFactory.getLogger(RemoteCallPool.class); - private static final int WAIT_TIME_MS = 5; - private static final long WARNING_PERIOD = 10000; - private final List pendingObjectBundles = new LinkedList<>(); - private Map> singletonHandlerMap = new ConcurrentHashMap<>(); - private Map>> bundleHandlerMap = - new ConcurrentHashMap<>(); - private Map> bundleExceptionHandlerMap = - new ConcurrentHashMap<>(); - private ThreadPoolExecutor callBackPool = new ThreadPoolExecutor( - 2, Runtime.getRuntime().availableProcessors(), - 1, TimeUnit.MINUTES, new LinkedBlockingQueue<>(), - new CallbackThreadFactory()); - private volatile boolean stop = false; - - public RemoteCallPool() { - Thread t = new Thread(Ray.wrapRunnable(this), "remote-pool-loop"); - t.setUncaughtExceptionHandler((thread, throwable) -> - LOG.error("Error in remote call pool thread.", throwable) - ); - t.start(); - } - - @SuppressWarnings("unchecked") - public void bindCallback( - ObjectRef obj, Callback callback, - ExceptionHandler onException) { - List objectRefList = Collections.singletonList(obj); - RemoteCallBundle bundle = new RemoteCallBundle(objectRefList, - true); - singletonHandlerMap.put(bundle, (Callback) callback); - bundleExceptionHandlerMap.put(bundle, onException); - synchronized (pendingObjectBundles) { - pendingObjectBundles.add(bundle); - } - } - - public void bindCallback( - List> objectBundle, Callback> callback, - ExceptionHandler onException) { - RemoteCallBundle bundle = new RemoteCallBundle(objectBundle, false); - bundleHandlerMap.put(bundle, callback); - bundleExceptionHandlerMap.put(bundle, onException); - synchronized (pendingObjectBundles) { - pendingObjectBundles.add(bundle); - } - } - - public void stop() { - stop = true; - } - - public void run() { - while (!stop) { - try { - if (pendingObjectBundles.isEmpty()) { - Thread.sleep(WAIT_TIME_MS); - continue; - } - synchronized (pendingObjectBundles) { - Iterator itr = pendingObjectBundles.iterator(); - while (itr.hasNext()) { - RemoteCallBundle bundle = itr.next(); - WaitResult waitResult = - Ray.wait(bundle.objects, bundle.objects.size(), WAIT_TIME_MS); - List> readyObjs = waitResult.getReady(); - if (readyObjs.size() != bundle.objects.size()) { - long now = System.currentTimeMillis(); - long waitingTime = now - bundle.createTime; - if (waitingTime > WARNING_PERIOD && now - bundle.lastWarnTs > WARNING_PERIOD) { - bundle.lastWarnTs = now; - LOG.warn("Bundle has being waiting for {} ms, bundle = {}.", waitingTime, bundle); - } - continue; - } - - ExceptionHandler exceptionHandler = bundleExceptionHandlerMap.get(bundle); - if (bundle.isSingletonBundle) { - callBackPool.execute(Ray.wrapRunnable(() -> { - try { - singletonHandlerMap.get(bundle).handle(readyObjs.get(0).get()); - singletonHandlerMap.remove(bundle); - } catch (Throwable th) { - LOG.error("Error when get object, objectId = {}.", readyObjs.get(0).toString(), - th); - if (exceptionHandler != null) { - exceptionHandler.handle(th); - } - } - })); - } else { - List results = - readyObjs.stream().map(ObjectRef::get).collect(Collectors.toList()); - List resultIds = - readyObjs.stream().map(ObjectRef::toString).collect(Collectors.toList()); - callBackPool.execute(Ray.wrapRunnable(() -> { - try { - bundleHandlerMap.get(bundle).handle(results); - bundleHandlerMap.remove(bundle); - } catch (Throwable th) { - LOG.error("Error when get object, objectIds = {}.", resultIds, th); - if (exceptionHandler != null) { - exceptionHandler.handle(th); - } - } - })); - } - itr.remove(); - } - } - - } catch (Exception e) { - LOG.error("Exception in wait loop.", e); - } - } - LOG.info("Wait loop finished."); - } - - @FunctionalInterface - public interface ExceptionHandler { - - void handle(T object); - } - - @FunctionalInterface - public interface Callback { - - void handle(T object) throws Throwable; - } - - private static class RemoteCallBundle { - - List> objects; - boolean isSingletonBundle; - long lastWarnTs = System.currentTimeMillis(); - long createTime = System.currentTimeMillis(); - - RemoteCallBundle(List> objects, boolean isSingletonBundle) { - this.objects = objects; - this.isSingletonBundle = isSingletonBundle; - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("["); - objects.forEach(rayObj -> sb.append(rayObj.toString()).append(",")); - sb.append("]"); - return sb.toString(); - } - } - - static class CallbackThreadFactory implements ThreadFactory { - - private AtomicInteger cnt = new AtomicInteger(0); - - @Override - public Thread newThread(Runnable r) { - Thread t = new Thread(r); - t.setUncaughtExceptionHandler((thread, throwable) -> LOG.error("Callback err.", throwable)); - t.setName("callback-thread-" + cnt.getAndIncrement()); - return t; - } - } - -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelId.java similarity index 97% rename from streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java rename to streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelId.java index 07e98ae3f..75904e19e 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelId.java @@ -1,4 +1,4 @@ -package io.ray.streaming.runtime.transfer.channel; +package io.ray.streaming.runtime.transfer; import com.google.common.base.FinalizablePhantomReference; import com.google.common.base.FinalizableReferenceQueue; @@ -41,6 +41,47 @@ public class ChannelId { this.nativeIdPtr = nativeIdPtr; } + public byte[] getBytes() { + return bytes; + } + + public ByteBuffer getBuffer() { + return buffer; + } + + public long getAddress() { + return address; + } + + public long getNativeIdPtr() { + if (nativeIdPtr == 0) { + throw new IllegalStateException("native ID not available"); + } + return nativeIdPtr; + } + + @Override + public String toString() { + return strId; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ChannelId that = (ChannelId) o; + return strId.equals(that.strId); + } + + @Override + public int hashCode() { + return strId.hashCode(); + } + private static native long createNativeId(long idAddress); private static native void destroyNativeId(long nativeIdPtr); @@ -123,7 +164,7 @@ public class ChannelId { * @param id hex string representation of channel id * @return bytes representation of channel id */ - public static byte[] idStrToBytes(String id) { + static byte[] idStrToBytes(String id) { byte[] idBytes = BaseEncoding.base16().decode(id.toUpperCase()); assert idBytes.length == ChannelId.ID_LENGTH; return idBytes; @@ -133,51 +174,10 @@ public class ChannelId { * @param id bytes representation of channel id * @return hex string representation of channel id */ - public static String idBytesToStr(byte[] id) { + static String idBytesToStr(byte[] id) { assert id.length == ChannelId.ID_LENGTH; return BaseEncoding.base16().encode(id).toLowerCase(); } - public byte[] getBytes() { - return bytes; - } - - public ByteBuffer getBuffer() { - return buffer; - } - - public long getAddress() { - return address; - } - - public long getNativeIdPtr() { - if (nativeIdPtr == 0) { - throw new IllegalStateException("native ID not available"); - } - return nativeIdPtr; - } - - @Override - public String toString() { - return strId; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ChannelId that = (ChannelId) o; - return strId.equals(that.strId); - } - - @Override - public int hashCode() { - return strId.hashCode(); - } - } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelUtils.java similarity index 94% rename from streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java rename to streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelUtils.java index 74e813134..c62b21018 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelUtils.java @@ -1,4 +1,4 @@ -package io.ray.streaming.runtime.transfer.channel; +package io.ray.streaming.runtime.transfer; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.generated.Streaming; @@ -10,7 +10,7 @@ public class ChannelUtils { private static final Logger LOGGER = LoggerFactory.getLogger(ChannelUtils.class); - public static byte[] toNativeConf(StreamingWorkerConfig workerConfig) { + static byte[] toNativeConf(StreamingWorkerConfig workerConfig) { Streaming.StreamingConfig.Builder builder = Streaming.StreamingConfig.newBuilder(); // job name diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataMessage.java new file mode 100644 index 000000000..6c8f08d8e --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataMessage.java @@ -0,0 +1,55 @@ +package io.ray.streaming.runtime.transfer; + +import java.nio.ByteBuffer; + +/** + * DataMessage represents data between upstream and downstream operator + */ +public class DataMessage implements Message { + + private final ByteBuffer body; + private final long msgId; + private final long timestamp; + private final String channelId; + + public DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) { + this.body = body; + this.timestamp = timestamp; + this.msgId = msgId; + this.channelId = channelId; + } + + @Override + public ByteBuffer body() { + return body; + } + + @Override + public long timestamp() { + return timestamp; + } + + /** + * @return message id + */ + public long msgId() { + return msgId; + } + + /** + * @return string id of channel where data is coming from + */ + public String channelId() { + return channelId; + } + + @Override + public String toString() { + return "DataMessage{" + + "body=" + body + + ", msgId=" + msgId + + ", timestamp=" + timestamp + + ", channelId='" + channelId + '\'' + + '}'; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java index f10571796..3cdf15a07 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java @@ -4,22 +4,11 @@ import com.google.common.base.Preconditions; import io.ray.api.BaseActorHandle; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.config.types.TransferChannelType; -import io.ray.streaming.runtime.transfer.channel.ChannelId; -import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; -import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo.ChannelCreationStatus; -import io.ray.streaming.runtime.transfer.channel.ChannelUtils; -import io.ray.streaming.runtime.transfer.channel.OffsetInfo; -import io.ray.streaming.runtime.transfer.message.BarrierMessage; -import io.ray.streaming.runtime.transfer.message.ChannelMessage; -import io.ray.streaming.runtime.transfer.message.DataMessage; import io.ray.streaming.runtime.util.Platform; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.ArrayList; -import java.util.HashMap; import java.util.LinkedList; import java.util.List; -import java.util.Map; import java.util.Queue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,20 +22,7 @@ public class DataReader { private static final Logger LOG = LoggerFactory.getLogger(DataReader.class); private long nativeReaderPtr; - // params set by getBundleNative: bundle data address + size - private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24); - // We use direct buffer to reduce gc overhead and memory copy. - private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0); - private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH); - - private final Map queueCreationStatusMap = new HashMap<>(); - private Queue buf = new LinkedList<>(); - - { - getBundleParams.order(ByteOrder.nativeOrder()); - bundleData.order(ByteOrder.nativeOrder()); - bundleMeta.order(ByteOrder.nativeOrder()); - } + private Queue buf = new LinkedList<>(); /** * @param inputChannels input channels ids @@ -56,7 +32,6 @@ public class DataReader { public DataReader( List inputChannels, List fromActors, - Map checkpoints, StreamingWorkerConfig workerConfig) { Preconditions.checkArgument(inputChannels.size() > 0); Preconditions.checkArgument(inputChannels.size() == fromActors.size()); @@ -64,16 +39,11 @@ public class DataReader { new ChannelCreationParametersBuilder().buildInputQueueParameters(inputChannels, fromActors); byte[][] inputChannelsBytes = inputChannels.stream() .map(ChannelId::idStrToBytes).toArray(byte[][]::new); - - // get sequence ID and message ID from OffsetInfo + long[] seqIds = new long[inputChannels.size()]; long[] msgIds = new long[inputChannels.size()]; for (int i = 0; i < inputChannels.size(); i++) { - String channelId = inputChannels.get(i); - if (!checkpoints.containsKey(channelId)) { - msgIds[i] = 0; - continue; - } - msgIds[i] = checkpoints.get(inputChannels.get(i)).getStreamingMsgId(); + seqIds[i] = 0; + msgIds[i] = 0; } long timerInterval = workerConfig.transferConfig.readerTimerIntervalMs(); TransferChannelType channelType = workerConfig.transferConfig.channelType(); @@ -81,34 +51,33 @@ public class DataReader { if (TransferChannelType.MEMORY_CHANNEL == channelType) { isMock = true; } + boolean isRecreate = workerConfig.transferConfig.readerIsRecreate(); - // create native reader - List creationStatus = new ArrayList<>(); this.nativeReaderPtr = createDataReaderNative( initialParameters, inputChannelsBytes, + seqIds, msgIds, timerInterval, - creationStatus, + isRecreate, ChannelUtils.toNativeConf(workerConfig), isMock ); - for (int i = 0; i < inputChannels.size(); ++i) { - queueCreationStatusMap - .put(inputChannels.get(i), ChannelCreationStatus.fromInt(creationStatus.get(i))); - } - LOG.info("Create DataReader succeed for worker: {}, creation status={}.", - workerConfig.workerInternalConfig.workerName(), queueCreationStatusMap); + LOG.info("Create DataReader succeed for worker: {}.", + workerConfig.workerInternalConfig.workerName()); } - private static native long createDataReaderNative( - ChannelCreationParametersBuilder initialParameters, - byte[][] inputChannels, - long[] msgIds, - long timerInterval, - List creationStatus, - byte[] configBytes, - boolean isMock); + // params set by getBundleNative: bundle data address + size + private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24); + // We use direct buffer to reduce gc overhead and memory copy. + private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0); + private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH); + + { + getBundleParams.order(ByteOrder.nativeOrder()); + bundleData.order(ByteOrder.nativeOrder()); + bundleMeta.order(ByteOrder.nativeOrder()); + } /** * Read message from input channels, if timeout, return null. @@ -116,21 +85,26 @@ public class DataReader { * @param timeoutMillis timeout * @return message or null */ - public ChannelMessage read(long timeoutMillis) { + public DataMessage read(long timeoutMillis) { if (buf.isEmpty()) { getBundle(timeoutMillis); // if bundle not empty. empty message still has data size + seqId + msgId if (bundleData.position() < bundleData.limit()) { BundleMeta bundleMeta = new BundleMeta(this.bundleMeta); - String channelID = bundleMeta.getChannelID(); - long timestamp = bundleMeta.getBundleTs(); // barrier if (bundleMeta.getBundleType() == DataBundleType.BARRIER) { - buf.offer(getBarrier(bundleData, channelID, timestamp)); + throw new UnsupportedOperationException( + "Unsupported bundle type " + bundleMeta.getBundleType()); } else if (bundleMeta.getBundleType() == DataBundleType.BUNDLE) { + String channelID = bundleMeta.getChannelID(); + long timestamp = bundleMeta.getBundleTs(); for (int i = 0; i < bundleMeta.getMessageListSize(); i++) { buf.offer(getDataMessage(bundleData, channelID, timestamp)); } + } else if (bundleMeta.getBundleType() == DataBundleType.EMPTY) { + long messageId = bundleMeta.getLastMessageId(); + buf.offer(new DataMessage(null, bundleMeta.getBundleTs(), + messageId, bundleMeta.getChannelID())); } } } @@ -140,31 +114,6 @@ public class DataReader { return buf.poll(); } - public ChannelRecoverInfo getQueueRecoverInfo() { - return new ChannelRecoverInfo(queueCreationStatusMap); - } - - private String getQueueIdString(ByteBuffer buffer) { - byte[] bytes = new byte[ChannelId.ID_LENGTH]; - buffer.get(bytes); - return ChannelId.idBytesToStr(bytes); - } - - private BarrierMessage getBarrier(ByteBuffer bundleData, String channelID, long timestamp) { - ByteBuffer offsetsInfoBytes = ByteBuffer.wrap(getOffsetsInfoNative(nativeReaderPtr)); - offsetsInfoBytes.order(ByteOrder.nativeOrder()); - BarrierOffsetInfo offsetInfo = new BarrierOffsetInfo(offsetsInfoBytes); - DataMessage message = getDataMessage(bundleData, channelID, timestamp); - BarrierItem barrierItem = new BarrierItem(message, offsetInfo); - return new BarrierMessage( - message.getMsgId(), - message.getTimestamp(), - message.getChannelId(), - barrierItem.getData(), - barrierItem.getGlobalBarrierId(), - barrierItem.getBarrierOffsetInfo().getQueueOffsetInfo()); - } - private DataMessage getDataMessage(ByteBuffer bundleData, String channelID, long timestamp) { int dataSize = bundleData.getInt(); // msgId @@ -212,14 +161,22 @@ public class DataReader { LOG.info("Finish closing DataReader."); } + private static native long createDataReaderNative( + ChannelCreationParametersBuilder initialParameters, + byte[][] inputChannels, + long[] seqIds, + long[] msgIds, + long timerInterval, + boolean isRecreate, + byte[] configBytes, + boolean isMock); + private native void getBundleNative( long nativeReaderPtr, long timeoutMillis, long params, long metaAddress); - private native byte[] getOffsetsInfoNative(long nativeQueueConsumerPtr); - private native void stopReaderNative(long nativeReaderPtr); private native void closeReaderNative(long nativeReaderPtr); @@ -236,16 +193,7 @@ public class DataReader { } } - public enum BarrierType { - GLOBAL_BARRIER(0); - private int code; - - BarrierType(int code) { - this.code = code; - } - } - - class BundleMeta { + static class BundleMeta { // kMessageBundleHeaderSize + kUniqueIDSize: // magicNum(4b) + bundleTs(8b) + lastMessageId(8b) + messageListSize(4b) @@ -278,7 +226,13 @@ public class DataReader { } // rawBundleSize rawBundleSize = buffer.getInt(); - channelID = getQueueIdString(buffer); + channelID = getQidString(buffer); + } + + private String getQidString(ByteBuffer buffer) { + byte[] bytes = new byte[ChannelId.ID_LENGTH]; + buffer.get(bytes); + return ChannelId.idBytesToStr(bytes); } public int getMagicNum() { @@ -310,73 +264,4 @@ public class DataReader { } } - class BarrierOffsetInfo { - - private int queueSize; - private Map queueOffsetInfo; - - public BarrierOffsetInfo(ByteBuffer buffer) { - // deserialization offset - queueSize = buffer.getInt(); - queueOffsetInfo = new HashMap<>(queueSize); - for (int i = 0; i < queueSize; ++i) { - String qid = getQueueIdString(buffer); - long streamingMsgId = buffer.getLong(); - queueOffsetInfo.put(qid, new OffsetInfo(streamingMsgId)); - } - } - - public int getQueueSize() { - return queueSize; - } - - public Map getQueueOffsetInfo() { - return queueOffsetInfo; - } - } - - class BarrierItem { - - BarrierOffsetInfo barrierOffsetInfo; - private long msgId; - private BarrierType barrierType; - private long globalBarrierId; - private ByteBuffer data; - - public BarrierItem(DataMessage message, BarrierOffsetInfo barrierOffsetInfo) { - this.barrierOffsetInfo = barrierOffsetInfo; - msgId = message.getMsgId(); - ByteBuffer buffer = message.body(); - // c++ use native order, so use native order here. - buffer.order(ByteOrder.nativeOrder()); - int barrierTypeInt = buffer.getInt(); - globalBarrierId = buffer.getLong(); - // dataSize includes: barrier type(32 bit), globalBarrierId, data - data = buffer.slice(); - data.order(ByteOrder.nativeOrder()); - buffer.position(buffer.limit()); - barrierType = BarrierType.GLOBAL_BARRIER; - } - - public long getBarrierMsgId() { - return msgId; - } - - public BarrierType getBarrierType() { - return barrierType; - } - - public long getGlobalBarrierId() { - return globalBarrierId; - } - - public ByteBuffer getData() { - return data; - } - - public BarrierOffsetInfo getBarrierOffsetInfo() { - return barrierOffsetInfo; - } - } - } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java index 55729c7fb..a8cebabb0 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java @@ -4,15 +4,10 @@ import com.google.common.base.Preconditions; import io.ray.api.BaseActorHandle; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.config.types.TransferChannelType; -import io.ray.streaming.runtime.transfer.channel.ChannelId; -import io.ray.streaming.runtime.transfer.channel.ChannelUtils; -import io.ray.streaming.runtime.transfer.channel.OffsetInfo; import io.ray.streaming.runtime.util.Platform; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -27,7 +22,6 @@ public class DataWriter { private long nativeWriterPtr; private ByteBuffer buffer = ByteBuffer.allocateDirect(0); private long bufferAddress; - private List outputChannels; { ensureBuffer(0); @@ -37,33 +31,21 @@ public class DataWriter { * @param outputChannels output channels ids * @param toActors downstream output actors * @param workerConfig configuration - * @param checkpoints offset of each channels */ public DataWriter( List outputChannels, List toActors, - Map checkpoints, StreamingWorkerConfig workerConfig) { Preconditions.checkArgument(!outputChannels.isEmpty()); Preconditions.checkArgument(outputChannels.size() == toActors.size()); - this.outputChannels = outputChannels; - ChannelCreationParametersBuilder initialParameters = new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors); - byte[][] outputChannelsBytes = outputChannels.stream() .map(ChannelId::idStrToBytes).toArray(byte[][]::new); long channelSize = workerConfig.transferConfig.channelSize(); - - // load message id from checkpoints long[] msgIds = new long[outputChannels.size()]; for (int i = 0; i < outputChannels.size(); i++) { - String channelId = outputChannels.get(i); - if (!checkpoints.containsKey(channelId)) { - msgIds[i] = 0; - continue; - } - msgIds[i] = checkpoints.get(channelId).getStreamingMsgId(); + msgIds[i] = 0; } TransferChannelType channelType = workerConfig.transferConfig.channelType(); boolean isMock = false; @@ -82,14 +64,6 @@ public class DataWriter { workerConfig.workerInternalConfig.workerName()); } - private static native long createWriterNative( - ChannelCreationParametersBuilder initialParameters, - byte[][] outputQueueIds, - long[] msgIds, - long channelSize, - byte[] confBytes, - boolean isMock); - /** * Write msg into the specified channel * @@ -108,8 +82,9 @@ public class DataWriter { * Write msg into the specified channels * * @param ids channel ids - * @param item message item data section is specified by [position, limit). - * item doesn't have to be a direct buffer. + * @param item message item data section is specified by [position, limit). item doesn't have + * to + * be a direct buffer. */ public void write(Set ids, ByteBuffer item) { int size = item.remaining(); @@ -129,27 +104,6 @@ public class DataWriter { } } - public Map getOutputCheckpoints() { - long[] msgId = getOutputMsgIdNative(nativeWriterPtr); - Map res = new HashMap<>(outputChannels.size()); - for (int i = 0; i < outputChannels.size(); ++i) { - res.put(outputChannels.get(i), new OffsetInfo(msgId[i])); - } - LOG.info("got output points, {}.", res); - return res; - } - - public void broadcastBarrier(long checkpointId, ByteBuffer attach) { - LOG.info("Broadcast barrier, cpId={}.", checkpointId); - Preconditions.checkArgument(attach.order() == ByteOrder.nativeOrder()); - broadcastBarrierNative(nativeWriterPtr, checkpointId, attach.array()); - } - - public void clearCheckpoint(long checkpointId) { - LOG.info("Producer clear checkpoint, checkpointId={}.", checkpointId); - clearCheckpointNative(nativeWriterPtr, checkpointId); - } - /** * stop writer */ @@ -170,6 +124,14 @@ public class DataWriter { LOG.info("Finish closing data writer."); } + private static native long createWriterNative( + ChannelCreationParametersBuilder initialParameters, + byte[][] outputQueueIds, + long[] msgIds, + long channelSize, + byte[] confBytes, + boolean isMock); + private native long writeMessageNative( long nativeQueueProducerPtr, long nativeIdPtr, long address, int size); @@ -177,15 +139,4 @@ public class DataWriter { private native void closeWriterNative(long nativeQueueProducerPtr); - private native long[] getOutputMsgIdNative(long nativeQueueProducerPtr); - - private native void broadcastBarrierNative( - long nativeQueueProducerPtr, long checkpointId, - byte[] data); - - private native void clearCheckpointNative( - long nativeQueueProducerPtr, - long checkpointId - ); - } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/Message.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/Message.java new file mode 100644 index 000000000..f48cb6f77 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/Message.java @@ -0,0 +1,22 @@ +package io.ray.streaming.runtime.transfer; + +import java.nio.ByteBuffer; + +public interface Message { + + /** + * Message data + *

+ * Message body is a direct byte buffer, which may be invalid after call next + * DataReader#getBundleNative. Please consume this buffer fully + * before next call getBundleNative. + * + * @return message body + */ + ByteBuffer body(); + + /** + * @return timestamp when item is written by upstream DataWriter + */ + long timestamp(); +} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java deleted file mode 100644 index 584f411ee..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java +++ /dev/null @@ -1,60 +0,0 @@ -package io.ray.streaming.runtime.transfer.channel; - -import com.google.common.base.MoreObjects; -import java.io.Serializable; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -public class ChannelRecoverInfo implements Serializable { - - private static final Logger LOG = LoggerFactory.getLogger(ChannelRecoverInfo.class); - public Map queueCreationStatusMap; - - - public ChannelRecoverInfo(Map queueCreationStatusMap) { - this.queueCreationStatusMap = queueCreationStatusMap; - } - - public Set getDataLostQueues() { - Set dataLostQueues = new HashSet<>(); - queueCreationStatusMap.forEach((q, status) -> { - if (status.equals(ChannelCreationStatus.DataLost)) { - dataLostQueues.add(q); - } - }); - return dataLostQueues; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("dataLostQueues", getDataLostQueues()) - .toString(); - } - - public enum ChannelCreationStatus { - FreshStarted(0), - PullOk(1), - Timeout(2), - DataLost(3); - - private int id; - - ChannelCreationStatus(int id) { - this.id = id; - } - - public static ChannelCreationStatus fromInt(int id) { - for (ChannelCreationStatus status : ChannelCreationStatus.values()) { - if (status.id == id) { - return status; - } - } - return null; - } - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java deleted file mode 100644 index 5c3ea02a7..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java +++ /dev/null @@ -1,31 +0,0 @@ -package io.ray.streaming.runtime.transfer.channel; - -import com.google.common.base.MoreObjects; -import java.io.Serializable; - -/** - * This data structure contains offset used by streaming queue. - */ -public class OffsetInfo implements Serializable { - - private long streamingMsgId; - - public OffsetInfo(long streamingMsgId) { - this.streamingMsgId = streamingMsgId; - } - - public long getStreamingMsgId() { - return streamingMsgId; - } - - public void setStreamingMsgId(long streamingMsgId) { - this.streamingMsgId = streamingMsgId; - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("streamingMsgId", streamingMsgId) - .toString(); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java deleted file mode 100644 index f4d909ce7..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java +++ /dev/null @@ -1,22 +0,0 @@ -package io.ray.streaming.runtime.transfer.exception; - -import io.ray.streaming.runtime.transfer.DataReader; -import io.ray.streaming.runtime.transfer.DataWriter; -import io.ray.streaming.runtime.transfer.channel.ChannelId; -import java.nio.ByteBuffer; - -/** - * when {@link DataReader#stop()} or {@link DataWriter#stop()} is called, this exception might be - * thrown in {@link DataReader#read(long)} and {@link DataWriter#write(ChannelId, ByteBuffer)}, - * which means the read/write operation is failed. - */ -public class ChannelInterruptException extends RuntimeException { - - public ChannelInterruptException() { - super(); - } - - public ChannelInterruptException(String message) { - super(message); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java deleted file mode 100644 index ffc694c53..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java +++ /dev/null @@ -1,34 +0,0 @@ -package io.ray.streaming.runtime.transfer.message; - -import io.ray.streaming.runtime.transfer.channel.OffsetInfo; -import java.nio.ByteBuffer; -import java.util.Map; - - -public class BarrierMessage extends ChannelMessage { - - private final ByteBuffer data; - private final long checkpointId; - private final Map inputOffsets; - - public BarrierMessage( - long msgId, long timestamp, String channelId, - ByteBuffer data, long checkpointId, Map inputOffsets) { - super(msgId, timestamp, channelId); - this.data = data; - this.checkpointId = checkpointId; - this.inputOffsets = inputOffsets; - } - - public ByteBuffer getData() { - return data; - } - - public long getCheckpointId() { - return checkpointId; - } - - public Map getInputOffsets() { - return inputOffsets; - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java deleted file mode 100644 index 6bfa4dca5..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java +++ /dev/null @@ -1,26 +0,0 @@ -package io.ray.streaming.runtime.transfer.message; - -public class ChannelMessage { - - private final long msgId; - private final long timestamp; - private final String channelId; - - public ChannelMessage(long msgId, long timestamp, String channelId) { - this.msgId = msgId; - this.timestamp = timestamp; - this.channelId = channelId; - } - - public long getMsgId() { - return msgId; - } - - public long getTimestamp() { - return timestamp; - } - - public String getChannelId() { - return channelId; - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java deleted file mode 100644 index b3cf779bf..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.ray.streaming.runtime.transfer.message; - -import java.nio.ByteBuffer; - - -/** - * DataMessage represents data between upstream and downstream operators. - */ -public class DataMessage extends ChannelMessage { - - private final ByteBuffer body; - - public DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) { - super(msgId, timestamp, channelId); - this.body = body; - } - - public ByteBuffer body() { - return body; - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java deleted file mode 100644 index c32d2ef4f..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java +++ /dev/null @@ -1,59 +0,0 @@ -package io.ray.streaming.runtime.util; - -import io.ray.streaming.runtime.context.ContextBackend; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Handle exception for checkpoint state - */ -public class CheckpointStateUtil { - - private static final Logger LOG = LoggerFactory.getLogger(CheckpointStateUtil.class); - - /** - * DO NOT ALLOW GET EXCEPTION WHEN LOADING CHECKPOINT - * - * @param checkpointState state backend - * @param cpKey checkpoint key - */ - public static byte[] get(ContextBackend checkpointState, String cpKey) { - byte[] val; - try { - val = checkpointState.get(cpKey); - } catch (Exception e) { - throw new CheckpointStateRuntimeException( - String.format("Failed to get %s from state backend.", cpKey), e); - } - return val; - } - - /** - * ALLOW PUT EXCEPTION WHEN SAVING CHECKPOINT - * - * @param checkpointState state backend - * @param key checkpoint key - * @param val checkpoint value - */ - public static void put(ContextBackend checkpointState, String key, byte[] val) { - try { - checkpointState.put(key, val); - } catch (Exception e) { - LOG.error("Failed to put key {} to state backend.", key, e); - } - } - - public static class CheckpointStateRuntimeException extends RuntimeException { - - public CheckpointStateRuntimeException() { - } - - public CheckpointStateRuntimeException(String message) { - super(message); - } - - public CheckpointStateRuntimeException(String message, Throwable cause) { - super(message, cause); - } - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java index 2238e82aa..f5120fb3a 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java @@ -3,29 +3,13 @@ package io.ray.streaming.runtime.util; import io.ray.runtime.RayNativeRuntime; import io.ray.runtime.util.JniUtils; import java.lang.management.ManagementFactory; -import java.net.InetAddress; -import java.net.UnknownHostException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class EnvUtil { - private static final Logger LOG = LoggerFactory.getLogger(EnvUtil.class); - public static String getJvmPid() { return ManagementFactory.getRuntimeMXBean().getName().split("@")[0]; } - public static String getHostName() { - String hostname = ""; - try { - hostname = InetAddress.getLocalHost().getHostName(); - } catch (UnknownHostException e) { - LOG.error("Error occurs while fetching local host.", e); - } - return hostname; - } - public static void loadNativeLibraries() { // Explicitly load `RayNativeRuntime`, to make sure `core_worker_library_java` // is loaded before `streaming_java`. diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java deleted file mode 100644 index 66777b702..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java +++ /dev/null @@ -1,220 +0,0 @@ -package io.ray.streaming.runtime.util; - -import com.sun.management.OperatingSystemMXBean; -import io.ray.api.id.UniqueId; -import io.ray.streaming.runtime.core.resource.Container; -import io.ray.streaming.runtime.core.resource.ContainerId; -import java.io.BufferedInputStream; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.lang.management.ManagementFactory; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -/** - * Resource Utility collects current OS and JVM resource usage information - */ -public class ResourceUtil { - - public static final Logger LOG = LoggerFactory.getLogger(ResourceUtil.class); - - /** - * Refer to: https://docs.oracle.com/javase/8/docs/jre/api/management/extension/com/sun/management/OperatingSystemMXBean.html - */ - private static OperatingSystemMXBean osmxb = - (OperatingSystemMXBean) ManagementFactory.getOperatingSystemMXBean(); - - /** - * Log current jvm process's memory detail - */ - public static void logProcessMemoryDetail() { - int mb = 1024 * 1024; - - //Getting the runtime reference from system - Runtime runtime = Runtime.getRuntime(); - - StringBuilder sb = new StringBuilder(32); - - sb.append("used memory: ").append((runtime.totalMemory() - runtime.freeMemory()) / mb) - .append(", free memory: ").append(runtime.freeMemory() / mb) - .append(", total memory: ").append(runtime.totalMemory() / mb) - .append(", max memory: ").append(runtime.maxMemory() / mb); - - if (LOG.isInfoEnabled()) { - LOG.info(sb.toString()); - } - } - - /** - * @return jvm heap usage ratio. note that one of the survivor space is not include in total - * memory while calculating this ratio. - */ - public static double getJvmHeapUsageRatio() { - Runtime runtime = Runtime.getRuntime(); - return (runtime.totalMemory() - runtime.freeMemory()) * 1.0 / runtime.maxMemory(); - } - - /** - * @return jvm heap usage(in bytes). - * note that this value doesn't include one of the survivor space. - */ - public static long getJvmHeapUsageInBytes() { - Runtime runtime = Runtime.getRuntime(); - return runtime.totalMemory() - runtime.freeMemory(); - } - - /** - * @return the total amount of physical memory in bytes. - */ - public static long getSystemTotalMemory() { - return osmxb.getTotalPhysicalMemorySize(); - } - - /** - * @return the used system physical memory in bytes - */ - public static long getSystemMemoryUsage() { - long totalMemory = osmxb.getTotalPhysicalMemorySize(); - long freeMemory = osmxb.getFreePhysicalMemorySize(); - return totalMemory - freeMemory; - } - - /** - * @return the ratio of used system physical memory. This value is a double in the [0.0,1.0] - */ - public static double getSystemMemoryUsageRatio() { - double totalMemory = osmxb.getTotalPhysicalMemorySize(); - double freeMemory = osmxb.getFreePhysicalMemorySize(); - double ratio = freeMemory / totalMemory; - return 1 - ratio; - } - - /** - * @return the cpu load for current jvm process. This value is a double in the [0.0,1.0] - */ - public static double getProcessCpuUsage() { - return osmxb.getProcessCpuLoad(); - } - - /** - * @return the system cpu usage. - * This value is a double in the [0.0,1.0] - * We will try to use `vsar` to get cpu usage by default, - * and use MXBean if any exception raised. - */ - public static double getSystemCpuUsage() { - double cpuUsage = 0.0; - try { - cpuUsage = getSystemCpuUtilByVsar(); - } catch (Exception e) { - cpuUsage = getSystemCpuUtilByMXBean(); - } - return cpuUsage; - } - - /** - * Returns the "recent cpu usage" for the whole system. This value is a double in the [0.0,1.0] - * interval. A value of 0.0 means that all CPUs were idle during the recent period of time - * observed, while a value of 1.0 means that all CPUs were actively running 100% of the time - * during the recent period being observed - */ - public static double getSystemCpuUtilByMXBean() { - return osmxb.getSystemCpuLoad(); - } - - /** - * Get system cpu util by vsar - */ - public static double getSystemCpuUtilByVsar() throws Exception { - double cpuUsageFromVsar = 0.0; - String[] vsarCpuCommand = {"/bin/sh", "-c", "vsar --check --cpu -s util"}; - try { - Process proc = Runtime.getRuntime().exec(vsarCpuCommand); - BufferedInputStream bis = new BufferedInputStream(proc.getInputStream()); - BufferedReader br = new BufferedReader(new InputStreamReader(bis)); - String line; - List processPidList = new ArrayList<>(); - while ((line = br.readLine()) != null) { - processPidList.add(line); - } - if (!processPidList.isEmpty()) { - String[] split = processPidList.get(0).split("="); - cpuUsageFromVsar = Double.parseDouble(split[1]) / 100.0D; - } else { - throw new IOException("Vsar check cpu usage failed, maybe vsar is not installed."); - } - } catch (Exception e) { - LOG.warn("Failed to get cpu usage by vsar.", e); - throw e; - } - return cpuUsageFromVsar; - } - - /** - * @returns the system load average for the last minute - */ - public static double getSystemLoadAverage() { - return osmxb.getSystemLoadAverage(); - } - - /** - * @return system cpu cores num - */ - public static int getCpuCores() { - return osmxb.getAvailableProcessors(); - } - - /** - * Get containers by hostname of address - * - * @param containers container list - * @param containerHosts container hostname or address set - * @return matched containers - */ - public static List getContainersByHostname( - List containers, - Collection containerHosts) { - - return containers.stream() - .filter(container -> - containerHosts.contains(container.getHostname()) || - containerHosts.contains(container.getAddress())) - .collect(Collectors.toList()); - } - - /** - * Get container by hostname - * - * @param hostName container hostname - * @return container - */ - public static Optional getContainerByHostname( - List containers, - String hostName) { - return containers.stream() - .filter(container -> container.getHostname().equals(hostName) || - container.getAddress().equals(hostName)) - .findFirst(); - } - - /** - * Get container by id - * - * @param containerID container id - * @return container - */ - public static Optional getContainerById( - List containers, - ContainerId containerID) { - return containers.stream() - .filter(container -> container.getId().equals(containerID)) - .findFirst(); - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java deleted file mode 100644 index 420215df1..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java +++ /dev/null @@ -1,15 +0,0 @@ -package io.ray.streaming.runtime.util; - -import io.ray.runtime.serializer.FstSerializer; - -public class Serializer { - - public static byte[] encode(Object obj) { - return FstSerializer.encode(obj); - } - - public static T decode(byte[] bytes) { - return FstSerializer.decode(bytes); - } - -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java index 7aac6b0c6..26a71453b 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java @@ -1,32 +1,20 @@ package io.ray.streaming.runtime.worker; -import io.ray.api.Ray; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.config.types.TransferChannelType; -import io.ray.streaming.runtime.context.ContextBackend; -import io.ray.streaming.runtime.context.ContextBackendFactory; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; import io.ray.streaming.runtime.core.processor.OneInputProcessor; import io.ray.streaming.runtime.core.processor.ProcessBuilder; import io.ray.streaming.runtime.core.processor.SourceProcessor; import io.ray.streaming.runtime.core.processor.StreamProcessor; import io.ray.streaming.runtime.master.JobMaster; -import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; -import io.ray.streaming.runtime.message.CallResult; -import io.ray.streaming.runtime.rpc.RemoteCallMaster; import io.ray.streaming.runtime.transfer.TransferHandler; -import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; -import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo.ChannelCreationStatus; -import io.ray.streaming.runtime.util.CheckpointStateUtil; import io.ray.streaming.runtime.util.EnvUtil; -import io.ray.streaming.runtime.util.Serializer; import io.ray.streaming.runtime.worker.context.JobWorkerContext; import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask; import io.ray.streaming.runtime.worker.tasks.SourceStreamTask; import io.ray.streaming.runtime.worker.tasks.StreamTask; import java.io.Serializable; -import java.util.concurrent.atomic.AtomicBoolean; -import org.apache.commons.lang3.exception.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,223 +36,90 @@ public class JobWorker implements Serializable { EnvUtil.loadNativeLibraries(); } - public final Object initialStateChangeLock = new Object(); - /** - * isRecreate=true means this worker is initialized more than once after actor created. - */ - public AtomicBoolean isRecreate = new AtomicBoolean(false); - public ContextBackend contextBackend; private JobWorkerContext workerContext; private ExecutionVertex executionVertex; private StreamingWorkerConfig workerConfig; - /** - * The while-loop thread to read message, process message, and write results - */ + private StreamTask task; - /** - * transferHandler handles messages by ray direct call - */ private TransferHandler transferHandler; - /** - * A flag to avoid duplicated rollback. Becomes true after requesting - * rollback, set to false when finish rollback. - */ - private boolean isNeedRollback = false; - private int rollbackCount = 0; - public JobWorker(ExecutionVertex executionVertex) { - LOG.info("Creating job worker."); - - // TODO: the following 3 lines is duplicated with that in init(), try to optimise it later. - this.executionVertex = executionVertex; - this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig()); - this.contextBackend = ContextBackendFactory.getContextBackend(this.workerConfig); - - LOG.info("Ray.getRuntimeContext().wasCurrentActorRestarted()={}", - Ray.getRuntimeContext().wasCurrentActorRestarted()); - if (!Ray.getRuntimeContext().wasCurrentActorRestarted()) { - saveContext(); - LOG.info("Job worker is fresh started, init success."); - return; - } - - LOG.info("Begin load job worker checkpoint state."); - - byte[] bytes = CheckpointStateUtil.get(contextBackend, getJobWorkerContextKey()); - if (bytes != null) { - JobWorkerContext context = Serializer.decode(bytes); - LOG.info("Worker recover from checkpoint state, byte len={}, context={}.", bytes.length, - context); - init(context); - requestRollback("LoadCheckpoint request rollback in new actor."); - } else { - LOG.error( - "Worker is reconstructed, but can't load checkpoint. " + - "Check whether you checkpoint state is reliable. Current checkpoint state is {}.", - contextBackend.getClass().getName()); - } - } - - public synchronized void saveContext() { - byte[] contextBytes = Serializer.encode(workerContext); - String key = getJobWorkerContextKey(); - LOG.info("Saving context, worker context={}, serialized byte length={}, key={}.", workerContext, - contextBytes.length, key); - CheckpointStateUtil.put(contextBackend, key, contextBytes); + public JobWorker() { + LOG.info("Creating job worker succeeded."); } /** * Initialize JobWorker and data communication pipeline. */ public Boolean init(JobWorkerContext workerContext) { - // IMPORTANT: some test cases depends on this log to find workers' pid, - // be careful when changing this log. - LOG.info("Initiating job worker: {}. Worker context is: {}, pid={}.", - workerContext.getWorkerName(), workerContext, EnvUtil.getJvmPid()); - - this.workerContext = workerContext; - this.executionVertex = workerContext.getExecutionVertex(); - this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig()); - // init state backend - this.contextBackend = ContextBackendFactory.getContextBackend(this.workerConfig); - - LOG.info("Initiating job worker succeeded: {}.", workerContext.getWorkerName()); - saveContext(); - return true; - } - - /** - * Start worker's stream tasks with specific checkpoint ID. - * - * @return a {@link CallResult} with {@link ChannelRecoverInfo}, - * contains {@link ChannelCreationStatus} of each input queue. - */ - public CallResult rollback(Long checkpointId, Long startRollbackTs) { - synchronized (initialStateChangeLock) { - if (task != null && task.isAlive() && checkpointId == task.lastCheckpointId && - task.isInitialState) { - return CallResult.skipped("Task is already in initial state, skip this rollback."); - } - } - long remoteCallCost = System.currentTimeMillis() - startRollbackTs; - - LOG.info("Start rollback[{}], checkpoint is {}, remote call cost {}ms.", - executionVertex.getExecutionJobVertexName(), checkpointId, remoteCallCost); - - rollbackCount++; - if (rollbackCount > 1) { - isRecreate.set(true); - } + LOG.info("Initiating job worker: {}. Worker context is: {}.", + workerContext.getWorkerName(), workerContext); try { + this.workerContext = workerContext; + this.executionVertex = workerContext.getExecutionVertex(); + this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig()); + //Init transfer TransferChannelType channelType = workerConfig.transferConfig.channelType(); if (TransferChannelType.NATIVE_CHANNEL == channelType) { transferHandler = new TransferHandler(); } - if (task != null) { - // make sure the task is closed - task.close(); - task = null; - } - // create stream task - task = createStreamTask(checkpointId); - ChannelRecoverInfo channelRecoverInfo = task.recover(isRecreate.get()); - isNeedRollback = false; - - LOG.info("Rollback job worker success, checkpoint is {}, channelRecoverInfo is {}.", - checkpointId, channelRecoverInfo); - - return CallResult.success(channelRecoverInfo); + task = createStreamTask(); + if (task == null) { + return false; + } } catch (Exception e) { - LOG.error("Rollback job worker has exception.", e); - return CallResult.fail(ExceptionUtils.getStackTrace(e)); + LOG.error("Failed to initiate job worker.", e); + return false; } + LOG.info("Initiating job worker succeeded: {}.", workerContext.getWorkerName()); + return true; + } + + /** + * Start worker's stream tasks. + * + * @return result + */ + public Boolean start() { + try { + task.start(); + } catch (Exception e) { + LOG.error("Start worker [{}] occur error.", executionVertex.getExecutionVertexName(), e); + return false; + } + return true; } /** * Create tasks based on the processor corresponding of the operator. */ - private StreamTask createStreamTask(long checkpointId) { - StreamTask task; + private StreamTask createStreamTask() { + StreamTask task = null; StreamProcessor streamProcessor = ProcessBuilder .buildProcessor(executionVertex.getStreamOperator()); LOG.debug("Stream processor created: {}.", streamProcessor); - if (streamProcessor instanceof SourceProcessor) { - task = new SourceStreamTask(streamProcessor, this, checkpointId); - } else if (streamProcessor instanceof OneInputProcessor) { - task = new OneInputStreamTask(streamProcessor, this, checkpointId); - } else { - throw new RuntimeException("Unsupported processor type:" + streamProcessor); + try { + if (streamProcessor instanceof SourceProcessor) { + task = new SourceStreamTask(getTaskId(), streamProcessor, this); + } else if (streamProcessor instanceof OneInputProcessor) { + task = new OneInputStreamTask(getTaskId(), streamProcessor, this); + } else { + throw new RuntimeException("Unsupported processor type:" + streamProcessor); + } + } catch (Exception e) { + LOG.info("Failed to create stream task.", e); + return task; } LOG.info("Stream task created: {}.", task); return task; } - // ---------------------------------------------------------------------- - // Checkpoint - // ---------------------------------------------------------------------- - - /** - * Trigger source job worker checkpoint - */ - public Boolean triggerCheckpoint(Long barrierId) { - LOG.info("Receive trigger, barrierId is {}.", barrierId); - if (task != null) { - return task.triggerCheckpoint(barrierId); - } - return false; - } - - public Boolean notifyCheckpointTimeout(Long checkpointId) { - LOG.info("Notify checkpoint timeout, checkpoint id is {}.", checkpointId); - if (task != null) { - task.notifyCheckpointTimeout(checkpointId); - } - return true; - } - - public Boolean clearExpiredCheckpoint(Long expiredStateCpId, Long expiredQueueCpId) { - LOG.info("Clear expired checkpoint state, checkpoint id is {}; " + - "Clear expired queue msg, checkpoint id is {}", - expiredStateCpId, expiredQueueCpId); - if (task != null) { - if (expiredStateCpId > 0) { - task.clearExpiredCpState(expiredStateCpId); - } - task.clearExpiredQueueMsg(expiredQueueCpId); - } - return true; - } - - // ---------------------------------------------------------------------- - // Failover - // ---------------------------------------------------------------------- - public void requestRollback(String exceptionMsg) { - LOG.info("Request rollback."); - isNeedRollback = true; - isRecreate.set(true); - boolean requestRet = RemoteCallMaster.requestJobWorkerRollback( - workerContext.getMaster(), new WorkerRollbackRequest( - workerContext.getWorkerActorId(), - exceptionMsg, - EnvUtil.getHostName(), - EnvUtil.getJvmPid() - )); - if (!requestRet) { - LOG.warn("Job worker request rollback failed! exceptionMsg={}.", exceptionMsg); - } - } - - public Boolean checkIfNeedRollback(Long startCallTs) { - // No save checkpoint in this query. - long remoteCallCost = System.currentTimeMillis() - startCallTs; - LOG.info("Finished checking if need to rollback with result: {}, rpc delay={}ms.", - isNeedRollback, remoteCallCost); - return isNeedRollback; + public int getTaskId() { + return executionVertex.getExecutionVertexId(); } public StreamingWorkerConfig getWorkerConfig() { @@ -283,19 +138,11 @@ public class JobWorker implements Serializable { return task; } - private String getJobWorkerContextKey() { - return workerConfig.checkpointConfig.jobWorkerContextCpPrefixKey() - + workerConfig.commonConfig.jobName() - + "_" + executionVertex.getExecutionVertexId(); - } - /** * Used by upstream streaming queue to send data to this actor */ public void onReaderMessage(byte[] buffer) { - if (transferHandler != null) { - transferHandler.onReaderMessage(buffer); - } + transferHandler.onReaderMessage(buffer); } /** @@ -312,9 +159,7 @@ public class JobWorker implements Serializable { * Used by downstream streaming queue to send data to this actor */ public void onWriterMessage(byte[] buffer) { - if (transferHandler != null) { - transferHandler.onWriterMessage(buffer); - } + transferHandler.onWriterMessage(buffer); } /** @@ -327,5 +172,4 @@ public class JobWorker implements Serializable { } return transferHandler.onWriterMessageSync(buffer); } - } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java index e4fd3b992..495a2b187 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java @@ -3,9 +3,7 @@ package io.ray.streaming.runtime.worker.context; import com.google.common.base.MoreObjects; import com.google.protobuf.ByteString; import io.ray.api.ActorHandle; -import io.ray.api.id.ActorId; import io.ray.runtime.actor.NativeActorHandle; -import io.ray.streaming.runtime.config.global.CommonConfig; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; import io.ray.streaming.runtime.generated.RemoteCall; import io.ray.streaming.runtime.master.JobMaster; @@ -35,10 +33,6 @@ public class JobWorkerContext implements Serializable { this.executionVertex = executionVertex; } - public ActorId getWorkerActorId() { - return executionVertex.getWorkerActorId(); - } - public int getWorkerId() { return executionVertex.getExecutionVertexId(); } @@ -59,14 +53,6 @@ public class JobWorkerContext implements Serializable { return executionVertex; } - public Map getConf() { - return getExecutionVertex().getWorkerConfig(); - } - - public String getJobName() { - return getConf().get(CommonConfig.JOB_NAME); - } - @Override public String toString() { return MoreObjects.toStringHelper(this) diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java index eeddf13e5..9ce0c5fb7 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java @@ -2,31 +2,22 @@ package io.ray.streaming.runtime.worker.tasks; import com.google.common.base.MoreObjects; import io.ray.streaming.runtime.core.processor.Processor; -import io.ray.streaming.runtime.generated.RemoteCall; import io.ray.streaming.runtime.serialization.CrossLangSerializer; import io.ray.streaming.runtime.serialization.JavaSerializer; import io.ray.streaming.runtime.serialization.Serializer; -import io.ray.streaming.runtime.transfer.channel.OffsetInfo; -import io.ray.streaming.runtime.transfer.exception.ChannelInterruptException; -import io.ray.streaming.runtime.transfer.message.BarrierMessage; -import io.ray.streaming.runtime.transfer.message.ChannelMessage; -import io.ray.streaming.runtime.transfer.message.DataMessage; +import io.ray.streaming.runtime.transfer.Message; import io.ray.streaming.runtime.worker.JobWorker; -import java.util.Map; -import org.apache.commons.lang3.exception.ExceptionUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public abstract class InputStreamTask extends StreamTask { - private static final Logger LOG = LoggerFactory.getLogger(InputStreamTask.class); - + private volatile boolean running = true; + private volatile boolean stopped = false; + private long readTimeoutMillis; private final io.ray.streaming.runtime.serialization.Serializer javaSerializer; private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer; - private final long readTimeoutMillis; - public InputStreamTask(Processor processor, JobWorker jobWorker, long lastCheckpointId) { - super(processor, jobWorker, lastCheckpointId); + public InputStreamTask(int taskId, Processor processor, JobWorker jobWorker) { + super(taskId, processor, jobWorker); readTimeoutMillis = jobWorker.getWorkerConfig().transferConfig.readerTimerIntervalMs(); javaSerializer = new JavaSerializer(); crossLangSerializer = new CrossLangSerializer(); @@ -38,64 +29,35 @@ public abstract class InputStreamTask extends StreamTask { @Override public void run() { - try { - while (running) { - ChannelMessage item; - - // reader.read() will change the consumer state once it got an item. This lock is to - // ensure worker can get correct isInitialState value in exactly-once-mode's rollback. - synchronized (jobWorker.initialStateChangeLock) { - item = reader.read(readTimeoutMillis); - if (item != null) { - isInitialState = false; - } else { - continue; - } + while (running) { + Message item = reader.read(readTimeoutMillis); + if (item != null) { + byte[] bytes = new byte[item.body().remaining() - 1]; + byte typeId = item.body().get(); + item.body().get(bytes); + Object obj; + if (typeId == Serializer.JAVA_TYPE_ID) { + obj = javaSerializer.deserialize(bytes); + } else { + obj = crossLangSerializer.deserialize(bytes); } - - if (item instanceof DataMessage) { - DataMessage dataMessage = (DataMessage) item; - byte[] bytes = new byte[dataMessage.body().remaining() - 1]; - byte typeId = dataMessage.body().get(); - dataMessage.body().get(bytes); - Object obj; - if (typeId == Serializer.JAVA_TYPE_ID) { - obj = javaSerializer.deserialize(bytes); - } else { - obj = crossLangSerializer.deserialize(bytes); - } - processor.process(obj); - } else if (item instanceof BarrierMessage) { - final BarrierMessage queueBarrier = (BarrierMessage) item; - byte[] barrierData = new byte[queueBarrier.getData().remaining()]; - queueBarrier.getData().get(barrierData); - RemoteCall.Barrier barrierPb = RemoteCall.Barrier.parseFrom(barrierData); - final long checkpointId = barrierPb.getId(); - LOG.info("Start to do checkpoint {}, worker name is {}.", checkpointId, - jobWorker.getWorkerContext().getWorkerName()); - - final Map inputPoints = queueBarrier.getInputOffsets(); - doCheckpoint(checkpointId, inputPoints); - LOG.info("Do checkpoint {} success.", checkpointId); - } - } - } catch (Throwable throwable) { - if (throwable instanceof ChannelInterruptException || - ExceptionUtils.getRootCause(throwable) instanceof ChannelInterruptException) { - LOG.info("queue has stopped."); - } else { - // error occurred, need to rollback - LOG.error("Last success checkpointId={}, now occur error.", lastCheckpointId, throwable); - requestRollback(ExceptionUtils.getStackTrace(throwable)); + processor.process(obj); } } - LOG.info("Input stream task thread exit."); stopped = true; } + @Override + protected void cancelTask() throws Exception { + running = false; + while (!stopped) { + } + } + @Override public String toString() { return MoreObjects.toStringHelper(this) + .add("taskId", taskId) .add("processor", processor) .toString(); } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java index 8eaf2ef66..16293f9ae 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java @@ -8,7 +8,7 @@ import io.ray.streaming.runtime.worker.JobWorker; */ public class OneInputStreamTask extends InputStreamTask { - public OneInputStreamTask(Processor inputProcessor, JobWorker jobWorker, long lastCheckpointId) { - super(inputProcessor, jobWorker, lastCheckpointId); + public OneInputStreamTask(int taskId, Processor inputProcessor, JobWorker jobWorker) { + super(taskId, inputProcessor, jobWorker); } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java index 9fc94c06d..3c70ece44 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java @@ -3,10 +3,7 @@ package io.ray.streaming.runtime.worker.tasks; import io.ray.streaming.operator.SourceOperator; import io.ray.streaming.runtime.core.processor.Processor; import io.ray.streaming.runtime.core.processor.SourceProcessor; -import io.ray.streaming.runtime.transfer.exception.ChannelInterruptException; import io.ray.streaming.runtime.worker.JobWorker; -import java.util.concurrent.atomic.AtomicReference; -import org.apache.commons.lang3.exception.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -16,19 +13,12 @@ public class SourceStreamTask extends StreamTask { private final SourceProcessor sourceProcessor; - - /** - * The pending barrier ID to be triggered. - */ - private final AtomicReference pendingBarrier = new AtomicReference<>(); - private long lastCheckpointId = 0; - /** * SourceStreamTask for executing a {@link SourceOperator}. It is responsible for running the * corresponding source operator. */ - public SourceStreamTask(Processor sourceProcessor, JobWorker jobWorker, long lastCheckpointId) { - super(sourceProcessor, jobWorker, lastCheckpointId); + public SourceStreamTask(int taskId, Processor sourceProcessor, JobWorker jobWorker) { + super(taskId, sourceProcessor, jobWorker); this.sourceProcessor = (SourceProcessor) processor; } @@ -39,48 +29,12 @@ public class SourceStreamTask extends StreamTask { @Override public void run() { LOG.info("Source stream task thread start."); - Long barrierId; - try { - while (running) { - isInitialState = false; - // check checkpoint - barrierId = pendingBarrier.get(); - if (barrierId != null) { - // Important: because cp maybe timeout, master will use the old checkpoint id again - if (pendingBarrier.compareAndSet(barrierId, null)) { - // source fetcher only have outputPoints - LOG.info("Start to do checkpoint {}, worker name is {}.", - barrierId, jobWorker.getWorkerContext().getWorkerName()); - - doCheckpoint(barrierId, null); - - LOG.info("Finish to do checkpoint {}.", barrierId); - } else { - // pendingCheckpointId has modify, should not happen - LOG.warn("Pending checkpointId modify unexpected, expect={}, now={}.", barrierId, - pendingBarrier.get()); - } - } - - sourceProcessor.fetch(); - } - } catch (Throwable e) { - if (e instanceof ChannelInterruptException || - ExceptionUtils.getRootCause(e) instanceof ChannelInterruptException) { - LOG.info("queue has stopped."); - } else { - // occur error, need to rollback - LOG.error("Last success checkpointId={}, now occur error.", lastCheckpointId, e); - requestRollback(ExceptionUtils.getStackTrace(e)); - } - } - - LOG.info("Source stream task thread exit."); + sourceProcessor.run(); } @Override - public boolean triggerCheckpoint(Long barrierId) { - return pendingBarrier.compareAndSet(null, barrierId); + protected void cancelTask() { } + } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java index 78ef0dbd4..79ad0100d 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java @@ -6,103 +6,53 @@ import io.ray.streaming.api.collector.Collector; import io.ray.streaming.api.context.RuntimeContext; import io.ray.streaming.api.partition.Partition; import io.ray.streaming.runtime.config.worker.WorkerInternalConfig; -import io.ray.streaming.runtime.context.ContextBackend; -import io.ray.streaming.runtime.context.OperatorCheckpointInfo; import io.ray.streaming.runtime.core.collector.OutputCollector; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge; -import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; import io.ray.streaming.runtime.core.processor.Processor; -import io.ray.streaming.runtime.generated.RemoteCall; -import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; -import io.ray.streaming.runtime.rpc.RemoteCallMaster; +import io.ray.streaming.runtime.transfer.ChannelId; import io.ray.streaming.runtime.transfer.DataReader; import io.ray.streaming.runtime.transfer.DataWriter; -import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; -import io.ray.streaming.runtime.transfer.channel.OffsetInfo; -import io.ray.streaming.runtime.util.CheckpointStateUtil; -import io.ray.streaming.runtime.util.Serializer; import io.ray.streaming.runtime.worker.JobWorker; -import io.ray.streaming.runtime.worker.context.JobWorkerContext; import io.ray.streaming.runtime.worker.context.StreamingRuntimeContext; -import java.io.Serializable; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * {@link StreamTask} is a while-loop thread to read message, process message, and send result - * messages to downstream operators - */ public abstract class StreamTask implements Runnable { private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class); - private final ContextBackend checkpointState; - public volatile boolean isInitialState = true; - public long lastCheckpointId; + + protected int taskId; protected Processor processor; protected JobWorker jobWorker; protected DataReader reader; - protected DataWriter writer; + List collectors = new ArrayList<>(); + protected volatile boolean running = true; protected volatile boolean stopped = false; - List collectors = new ArrayList<>(); - private Set outdatedCheckpoints = new HashSet<>(); + private Thread thread; - protected StreamTask(Processor processor, JobWorker jobWorker, long lastCheckpointId) { + protected StreamTask(int taskId, Processor processor, JobWorker jobWorker) { + this.taskId = taskId; this.processor = processor; this.jobWorker = jobWorker; - this.checkpointState = jobWorker.contextBackend; - this.lastCheckpointId = lastCheckpointId; + prepareTask(); this.thread = new Thread(Ray.wrapRunnable(this), this.getClass().getName() + "-" + System.currentTimeMillis()); this.thread.setDaemon(true); } - public ChannelRecoverInfo recover(boolean isRecover) { - - if (isRecover) { - LOG.info("Stream task begin recover."); - } else { - LOG.info("Stream task first start begin."); - } - prepareTask(isRecover); - - // start runner - ChannelRecoverInfo recoverInfo = new ChannelRecoverInfo(new HashMap<>()); - if (reader != null) { - recoverInfo = reader.getQueueRecoverInfo(); - } - - thread.setUncaughtExceptionHandler( - (t, e) -> LOG.error("Uncaught exception in runner thread.", e)); - LOG.info("Start stream task: {}.", this.getClass().getSimpleName()); - thread.start(); - - if (isRecover) { - LOG.info("Stream task recover end."); - } else { - LOG.info("Stream task first start finished."); - } - - return recoverInfo; - } - /** - * Load checkpoint and build upstream and downstream data transmission - * channels according to {@link ExecutionVertex}. + * Build upstream and downstream data transmission channels according to {@link ExecutionVertex}. */ - private void prepareTask(boolean isRecreate) { - LOG.info("Preparing stream task, isRecreate={}.", isRecreate); + private void prepareTask() { + LOG.debug("Preparing stream task."); ExecutionVertex executionVertex = jobWorker.getExecutionVertex(); // set vertex info into config for native using @@ -111,92 +61,73 @@ public abstract class StreamTask implements Runnable { jobWorker.getWorkerConfig().workerInternalConfig.setProperty( WorkerInternalConfig.OP_NAME_INTERNAL, executionVertex.getExecutionJobVertexName()); - OperatorCheckpointInfo operatorCheckpointInfo = new OperatorCheckpointInfo(); - byte[] bytes = null; + // producer - // Fetch checkpoint from storage only in recreate mode not for new startup worker - // in rescaling or something like that. - if (isRecreate) { - String cpKey = genOpCheckpointKey(lastCheckpointId); - LOG.info("Getting task checkpoints from state, cpKey={}, checkpointId={}.", cpKey, - lastCheckpointId); - bytes = CheckpointStateUtil.get(checkpointState, cpKey); - if (bytes == null) { - String msg = String.format("Task recover failed, checkpoint is null! cpKey=%s", cpKey); - throw new RuntimeException(msg); - } - } - - // when use memory state, if actor throw exception, will miss state - if (bytes != null) { - operatorCheckpointInfo = Serializer.decode(bytes); - processor.loadCheckpoint(operatorCheckpointInfo.processorCheckpoint); - LOG.info( - "Stream task recover from checkpoint state, checkpoint bytes len={}, checkpointInfo={}.", - bytes.length, operatorCheckpointInfo); - } - - // writer - if (!executionVertex.getOutputEdges().isEmpty()) { - LOG.info("Register queue writer, channels={}, outputCheckpoints={}.", - executionVertex.getOutputChannelIdList(), operatorCheckpointInfo.outputPoints); - writer = new DataWriter( - executionVertex.getOutputChannelIdList(), - executionVertex.getOutputActorList(), - operatorCheckpointInfo.outputPoints, - jobWorker.getWorkerConfig() - ); - } - - // reader - if (!executionVertex.getInputEdges().isEmpty()) { - LOG.info("Register queue reader, channels={}, inputCheckpoints={}.", - executionVertex.getInputChannelIdList(), operatorCheckpointInfo.inputPoints); - reader = new DataReader( - executionVertex.getInputChannelIdList(), - executionVertex.getInputActorList(), - operatorCheckpointInfo.inputPoints, - jobWorker.getWorkerConfig() - ); - } - - openProcessor(); - - LOG.debug("Finished preparing stream task."); - } - - /** - * Create one collector for each distinct output operator(i.e. each {@link ExecutionJobVertex}) - */ - private void openProcessor() { - ExecutionVertex executionVertex = jobWorker.getExecutionVertex(); List outputEdges = executionVertex.getOutputEdges(); - Map> opGroupedChannelId = new HashMap<>(); - Map> opGroupedActor = new HashMap<>(); - Map opPartitionMap = new HashMap<>(); - for (int i = 0; i < outputEdges.size(); ++i) { - ExecutionEdge edge = outputEdges.get(i); - String opName = edge.getTargetExecutionJobVertexName(); - if (!opPartitionMap.containsKey(opName)) { - opGroupedChannelId.put(opName, new ArrayList<>()); - opGroupedActor.put(opName, new ArrayList<>()); - } - opGroupedChannelId.get(opName).add(executionVertex.getOutputChannelIdList().get(i)); - opGroupedActor.get(opName).add(executionVertex.getOutputActorList().get(i)); - opPartitionMap.put(opName, edge.getPartition()); + // merge all output edges to create writer + List outputChannelIds = new ArrayList<>(); + List targetActors = new ArrayList<>(); + + for (ExecutionEdge edge : outputEdges) { + String channelId = ChannelId.genIdStr( + taskId, + edge.getTargetExecutionVertex().getExecutionVertexId(), + executionVertex.getBuildTime()); + outputChannelIds.add(channelId); + targetActors.add(edge.getTargetExecutionVertex().getWorkerActor()); + } + + if (!targetActors.isEmpty()) { + DataWriter writer = new DataWriter( + outputChannelIds, targetActors, jobWorker.getWorkerConfig() + ); + + // create a collector for each output operator + Map> opGroupedChannelId = new HashMap<>(); + Map> opGroupedActor = new HashMap<>(); + Map opPartitionMap = new HashMap<>(); + for (int i = 0; i < outputEdges.size(); ++i) { + ExecutionEdge edge = outputEdges.get(i); + String opName = edge.getTargetExecutionJobVertexName(); + if (!opPartitionMap.containsKey(opName)) { + opGroupedChannelId.put(opName, new ArrayList<>()); + opGroupedActor.put(opName, new ArrayList<>()); + } + opGroupedChannelId.get(opName).add(outputChannelIds.get(i)); + opGroupedActor.get(opName).add(targetActors.get(i)); + opPartitionMap.put(opName, edge.getPartition()); + } + opPartitionMap.keySet().forEach(opName -> { + collectors.add(new OutputCollector( + writer, opGroupedChannelId.get(opName), + opGroupedActor.get(opName), opPartitionMap.get(opName) + )); + }); + } + + // consumer + List inputEdges = executionVertex.getInputEdges(); + List inputChannelIds = new ArrayList<>(); + List inputActors = new ArrayList<>(); + for (ExecutionEdge edge : inputEdges) { + String queueName = ChannelId.genIdStr( + edge.getSourceExecutionVertex().getExecutionVertexId(), + taskId, + executionVertex.getBuildTime()); + inputChannelIds.add(queueName); + inputActors.add(edge.getSourceExecutionVertex().getWorkerActor()); + } + if (!inputActors.isEmpty()) { + LOG.info("Register queue consumer, channels {}.", inputChannelIds); + reader = new DataReader(inputChannelIds, inputActors, jobWorker.getWorkerConfig()); } - opPartitionMap.keySet().forEach(opName -> { - collectors.add(new OutputCollector( - writer, opGroupedChannelId.get(opName), - opGroupedActor.get(opName), opPartitionMap.get(opName) - )); - }); RuntimeContext runtimeContext = new StreamingRuntimeContext(executionVertex, jobWorker.getWorkerConfig().configMap, executionVertex.getParallelism()); processor.open(collectors, runtimeContext); + LOG.debug("Finished preparing stream task."); } /** @@ -204,6 +135,16 @@ public abstract class StreamTask implements Runnable { */ protected abstract void init() throws Exception; + /** + * Stop running tasks. + */ + protected abstract void cancelTask() throws Exception; + + public void start() { + LOG.info("Start stream task: {}-{}", this.getClass().getSimpleName(), taskId); + this.thread.start(); + } + /** * Close running tasks. */ @@ -218,134 +159,4 @@ public abstract class StreamTask implements Runnable { LOG.info("Stream task close success."); } - // ---------------------------------------------------------------------- - // Checkpoint - // ---------------------------------------------------------------------- - - public boolean triggerCheckpoint(Long barrierId) { - throw new UnsupportedOperationException("Only source operator supports trigger checkpoints."); - } - - public void doCheckpoint(long checkpointId, Map inputPoints) { - Map outputPoints = null; - if (writer != null) { - outputPoints = writer.getOutputCheckpoints(); - RemoteCall.Barrier barrierPb = - RemoteCall.Barrier.newBuilder().setId(checkpointId).build(); - ByteBuffer byteBuffer = ByteBuffer.wrap(barrierPb.toByteArray()); - byteBuffer.order(ByteOrder.nativeOrder()); - writer.broadcastBarrier(checkpointId, byteBuffer); - } - - LOG.info("Start do checkpoint, cp id={}, inputPoints={}, outputPoints={}.", checkpointId, - inputPoints, outputPoints); - - this.lastCheckpointId = checkpointId; - Serializable processorCheckpoint = processor.saveCheckpoint(); - - try { - OperatorCheckpointInfo opCpInfo = - new OperatorCheckpointInfo(inputPoints, outputPoints, processorCheckpoint, - checkpointId); - saveCpStateAndReport(opCpInfo, checkpointId); - } catch (Exception e) { - // there will be exceptions when flush state to backend. - // we ignore the exception to prevent failover - LOG.error("Processor or op checkpoint exception.", e); - } - - LOG.info("Operator do checkpoint {} finish.", checkpointId); - } - - private void saveCpStateAndReport( - OperatorCheckpointInfo operatorCheckpointInfo, - long checkpointId) { - saveCp(operatorCheckpointInfo, checkpointId); - reportCommit(checkpointId); - - LOG.info("Finish save cp state and report, checkpoint id is {}.", checkpointId); - } - - private void saveCp(OperatorCheckpointInfo operatorCheckpointInfo, long checkpointId) { - byte[] bytes = Serializer.encode(operatorCheckpointInfo); - String cpKey = genOpCheckpointKey(checkpointId); - LOG.info("Saving task checkpoint, cpKey={}, byte len={}, checkpointInfo={}.", cpKey, - bytes.length, operatorCheckpointInfo); - synchronized (checkpointState) { - if (outdatedCheckpoints.contains(checkpointId)) { - LOG.info("Outdated checkpoint, skip save checkpoint."); - outdatedCheckpoints.remove(checkpointId); - } else { - CheckpointStateUtil.put(checkpointState, cpKey, bytes); - } - } - } - - private void reportCommit(long checkpointId) { - final JobWorkerContext context = jobWorker.getWorkerContext(); - LOG.info("Report commit async, checkpoint id {}.", checkpointId); - RemoteCallMaster.reportJobWorkerCommitAsync(context.getMaster(), - new WorkerCommitReport(context.getWorkerActorId(), checkpointId)); - } - - public void notifyCheckpointTimeout(long checkpointId) { - String cpKey = genOpCheckpointKey(checkpointId); - try { - synchronized (checkpointState) { - if (checkpointState.exists(cpKey)) { - checkpointState.remove(cpKey); - } else { - outdatedCheckpoints.add(checkpointId); - } - } - } catch (Exception e) { - LOG.error("Notify checkpoint timeout failed, checkpointId is {}.", checkpointId, e); - } - } - - public void clearExpiredCpState(long checkpointId) { - String cpKey = genOpCheckpointKey(checkpointId); - try { - checkpointState.remove(cpKey); - } catch (Exception e) { - LOG.error("Failed to remove key {} from state backend.", cpKey, e); - } - } - - public void clearExpiredQueueMsg(long checkpointId) { - // get operator checkpoint - String cpKey = genOpCheckpointKey(checkpointId); - byte[] bytes; - try { - bytes = checkpointState.get(cpKey); - } catch (Exception e) { - LOG.error("Failed to get key {} from state backend.", cpKey, e); - return; - } - if (bytes != null) { - final OperatorCheckpointInfo operatorCheckpointInfo = Serializer.decode(bytes); - long cpId = operatorCheckpointInfo.checkpointId; - if (writer != null) { - writer.clearCheckpoint(cpId); - } - } - } - - public String genOpCheckpointKey(long checkpointId) { - // TODO: need to support job restart and actorId changed - final JobWorkerContext context = jobWorker.getWorkerContext(); - return jobWorker.getWorkerConfig().checkpointConfig.jobWorkerOpCpPrefixKey() - + context.getJobName() + "_" + context.getWorkerName() + "_" + checkpointId; - } - - // ---------------------------------------------------------------------- - // Failover - // ---------------------------------------------------------------------- - protected void requestRollback(String exceptionMsg) { - jobWorker.requestRollback(exceptionMsg); - } - - public boolean isAlive() { - return this.thread.isAlive(); - } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java index 40870f51a..1bba5b0f5 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java @@ -10,12 +10,12 @@ import io.ray.streaming.runtime.worker.JobWorker; public class TwoInputStreamTask extends InputStreamTask { public TwoInputStreamTask( + int taskId, Processor processor, JobWorker jobWorker, String leftStream, - String rightStream, - long lastCheckpointId) { - super(processor, jobWorker, lastCheckpointId); + String rightStream) { + super(taskId, processor, jobWorker); ((TwoInputProcessor) (super.processor)).setLeftStream(leftStream); ((TwoInputProcessor) (super.processor)).setRightStream(rightStream); } diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java index 9af1899ac..05d4f9dc8 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java @@ -15,7 +15,7 @@ import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; import io.ray.streaming.runtime.core.resource.ResourceType; -import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.JobRuntimeContext; import io.ray.streaming.runtime.master.graphmanager.GraphManager; import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; import java.util.HashMap; @@ -34,7 +34,7 @@ public class ExecutionGraphTest extends BaseUnitTest { public void testBuildExecutionGraph() { Map jobConf = new HashMap<>(); StreamingConfig streamingConfig = new StreamingConfig(jobConf); - GraphManager graphManager = new GraphManagerImpl(new JobMasterRuntimeContext(streamingConfig)); + GraphManager graphManager = new GraphManagerImpl(new JobRuntimeContext(streamingConfig)); JobGraph jobGraph = buildJobGraph(); jobGraph.getJobConfig().put("streaming.task.resource.cpu.limitation.enable", "true"); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java index b7e2aef61..8473937a9 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java @@ -36,7 +36,7 @@ public class UnionStreamTest { streamSource1 .union(streamSource2, streamSource3) .sink((SinkFunction) value -> { - LOG.info("UnionStreamTest, sink: {}", value); + LOG.info("UnionStreamTest: {}", value); try { if (!Files.exists(Paths.get(sinkFileName))) { Files.createFile(Paths.get(sinkFileName)); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java index 76658e1ea..53f3cc4d1 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java @@ -14,7 +14,7 @@ public class JobMasterTest { Assert.assertNull(jobMaster.getGraphManager()); Assert.assertNull(jobMaster.getResourceManager()); Assert.assertNull(jobMaster.getJobMasterActor()); - Assert.assertFalse(jobMaster.init(false)); + Assert.assertFalse(jobMaster.init()); } } diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java index 5f3e7db35..579b1266a 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java @@ -7,7 +7,7 @@ import io.ray.streaming.runtime.BaseUnitTest; import io.ray.streaming.runtime.config.StreamingConfig; import io.ray.streaming.runtime.config.global.CommonConfig; import io.ray.streaming.runtime.core.resource.Container; -import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.JobRuntimeContext; import io.ray.streaming.runtime.util.RayUtils; import java.util.HashMap; import java.util.List; @@ -44,8 +44,8 @@ public class ResourceManagerTest extends BaseUnitTest { Map conf = new HashMap(); conf.put(CommonConfig.JOB_NAME, "testApi"); StreamingConfig config = new StreamingConfig(conf); - JobMasterRuntimeContext jobMasterRuntimeContext = new JobMasterRuntimeContext(config); - ResourceManager resourceManager = new ResourceManagerImpl(jobMasterRuntimeContext); + JobRuntimeContext jobRuntimeContext = new JobRuntimeContext(config); + ResourceManager resourceManager = new ResourceManagerImpl(jobRuntimeContext); // test register container List containers = resourceManager.getRegisteredContainers(); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java index 2e42e606b..4a8ea66b1 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java @@ -9,7 +9,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionGraphTest; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.resource.Container; import io.ray.streaming.runtime.core.resource.ResourceType; -import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.JobRuntimeContext; import io.ray.streaming.runtime.master.graphmanager.GraphManager; import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; import io.ray.streaming.runtime.master.resourcemanager.ResourceAssignmentView; @@ -64,7 +64,7 @@ public class PipelineFirstStrategyTest extends BaseUnitTest { Map jobConf = new HashMap<>(); StreamingConfig streamingConfig = new StreamingConfig(jobConf); - GraphManager graphManager = new GraphManagerImpl(new JobMasterRuntimeContext(streamingConfig)); + GraphManager graphManager = new GraphManagerImpl(new JobRuntimeContext(streamingConfig)); JobGraph jobGraph = ExecutionGraphTest.buildJobGraph(); ExecutionGraph executionGraph = ExecutionGraphTest.buildExecutionGraph(graphManager, jobGraph); ResourceAssignmentView assignmentView = strategy.assignResource(containers, executionGraph); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java index 879364e04..b1760ceb6 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java @@ -9,7 +9,7 @@ import io.ray.streaming.api.function.impl.FlatMapFunction; import io.ray.streaming.api.function.impl.ReduceFunction; import io.ray.streaming.api.stream.DataStreamSource; import io.ray.streaming.runtime.BaseUnitTest; -import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.transfer.ChannelId; import io.ray.streaming.runtime.util.EnvUtil; import io.ray.streaming.util.Config; import java.io.File; diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java index 1305267e2..4e94d5167 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java @@ -6,11 +6,11 @@ import io.ray.api.Ray; import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.transfer.ChannelCreationParametersBuilder; +import io.ray.streaming.runtime.transfer.ChannelId; +import io.ray.streaming.runtime.transfer.DataMessage; import io.ray.streaming.runtime.transfer.DataReader; import io.ray.streaming.runtime.transfer.DataWriter; import io.ray.streaming.runtime.transfer.TransferHandler; -import io.ray.streaming.runtime.transfer.channel.ChannelId; -import io.ray.streaming.runtime.transfer.message.DataMessage; import io.ray.streaming.util.Config; import java.lang.management.ManagementFactory; import java.nio.ByteBuffer; @@ -104,7 +104,7 @@ class ReaderWorker extends Worker { new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessage", "([B)V"), new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessageSync", "([B)[B")); StreamingWorkerConfig workerConfig = new StreamingWorkerConfig(conf); - dataReader = new DataReader(inputQueueList, inputActors, new HashMap<>(), workerConfig); + dataReader = new DataReader(inputQueueList, inputActors, workerConfig); // Should not GetBundle in RayCall thread Thread readThread = new Thread(Ray.wrapRunnable(new Runnable() { @@ -124,7 +124,7 @@ class ReaderWorker extends Worker { int checkPointId = 1; for (int i = 0; i < msgCount * inputQueueList.size(); ++i) { - DataMessage dataMessage = (DataMessage) dataReader.read(100); + DataMessage dataMessage = dataReader.read(100); if (dataMessage == null) { LOGGER.error("dataMessage is null"); @@ -232,7 +232,7 @@ class WriterWorker extends Worker { new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessage", "([B)V"), new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessageSync", "([B)[B")); StreamingWorkerConfig workerConfig = new StreamingWorkerConfig(conf); - dataWriter = new DataWriter(outputQueueList, outputActors, new HashMap<>(), workerConfig); + dataWriter = new DataWriter(outputQueueList, outputActors, workerConfig); Thread writerThread = new Thread(Ray.wrapRunnable(new Runnable() { @Override public void run() { diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java index 46270837e..11dcddeda 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java @@ -4,7 +4,6 @@ import static org.testng.Assert.assertEquals; import io.ray.streaming.runtime.BaseUnitTest; -import io.ray.streaming.runtime.transfer.channel.ChannelId; import io.ray.streaming.runtime.util.EnvUtil; import org.testng.annotations.Test; diff --git a/streaming/java/streaming-runtime/src/test/resources/log4j.properties b/streaming/java/streaming-runtime/src/test/resources/log4j.properties index 8d40bd190..30d876aec 100644 --- a/streaming/java/streaming-runtime/src/test/resources/log4j.properties +++ b/streaming/java/streaming-runtime/src/test/resources/log4j.properties @@ -3,4 +3,4 @@ log4j.rootLogger=INFO, stdout log4j.appender.stdout=org.apache.log4j.ConsoleAppender log4j.appender.stdout.Target=System.out log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SS} %-4p %c{1}:%L [%t] - %m%n +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n diff --git a/streaming/python/collector.py b/streaming/python/collector.py index 1760900fc..12b6c096b 100644 --- a/streaming/python/collector.py +++ b/streaming/python/collector.py @@ -68,7 +68,7 @@ class OutputCollector(Collector): python_buffer = self.python_serializer.serialize(record) self._writer.write( self._channel_ids[partition_index], - bytes([serialization.PYTHON_TYPE_ID]) + python_buffer) + serialization._PYTHON_TYPE_ID + python_buffer) else: # avoid repeated serialization if cross_lang_buffer is None: @@ -76,5 +76,4 @@ class OutputCollector(Collector): record) self._writer.write( self._channel_ids[partition_index], - bytes([serialization.CROSS_LANG_TYPE_ID]) + - cross_lang_buffer) + serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer) diff --git a/streaming/python/config.py b/streaming/python/config.py index b80d49b29..d7af6230e 100644 --- a/streaming/python/config.py +++ b/streaming/python/config.py @@ -8,6 +8,7 @@ class Config: NATIVE_CHANNEL = "native_channel" CHANNEL_SIZE = "channel_size" CHANNEL_SIZE_DEFAULT = 10**8 + IS_RECREATE = "streaming.is_recreate" # return from StreamingReader.getBundle if only empty message read in this # interval. TIMER_INTERVAL_MS = "timer_interval_ms" @@ -25,38 +26,3 @@ class Config: FLOW_CONTROL_TYPE = "streaming.flow_control_type" WRITER_CONSUMED_STEP = "streaming.writer.consumed_step" READER_CONSUMED_STEP = "streaming.reader.consumed_step" - - # state backend - CP_STATE_BACKEND_TYPE = "streaming.context-backend.type" - CP_STATE_BACKEND_MEMORY = "memory" - CP_STATE_BACKEND_LOCAL_FILE = "local_file" - CP_STATE_BACKEND_DEFAULT = CP_STATE_BACKEND_MEMORY - - # local disk - FILE_STATE_ROOT_PATH = "streaming.context-backend.file-state.root" - FILE_STATE_ROOT_PATH_DEFAULT = "/tmp/ray_streaming_state" - - # checkpoint - JOB_WORKER_CONTEXT_KEY = "jobworker_context_" - - # reliability level - REQUEST_ROLLBACK_RETRY_TIMES = 3 - - # checkpoint prefix key - JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY = "jobwk_op_" - - -class ConfigHelper(object): - @staticmethod - def get_cp_local_file_root_dir(conf): - value = conf.get(Config.FILE_STATE_ROOT_PATH) - if value is not None: - return value - return Config.FILE_STATE_ROOT_PATH_DEFAULT - - @staticmethod - def get_cp_context_backend_type(conf): - value = conf.get(Config.CP_STATE_BACKEND_TYPE) - if value is not None: - return value - return Config.CP_STATE_BACKEND_DEFAULT diff --git a/streaming/python/function.py b/streaming/python/function.py index 038132d90..c9cfedcc6 100644 --- a/streaming/python/function.py +++ b/streaming/python/function.py @@ -22,12 +22,6 @@ class Function(ABC): def close(self): pass - def save_checkpoint(self): - pass - - def load_checkpoint(self, checkpoint_obj): - pass - class EmptyFunction(Function): """Default function which does nothing""" @@ -64,15 +58,12 @@ class SourceFunction(Function): pass @abstractmethod - def fetch(self, ctx: SourceContext): + def run(self, ctx: SourceContext): """Starts the source. Implementations can use the :class:`SourceContext` to emit elements. """ pass - def close(self): - pass - class MapFunction(Function): """ @@ -185,29 +176,24 @@ class CollectionSourceFunction(SourceFunction): def init(self, parallel, index): pass - def fetch(self, ctx: SourceContext): + def run(self, ctx: SourceContext): for v in self.values: ctx.collect(v) - self.values = [] class LocalFileSourceFunction(SourceFunction): def __init__(self, filename): self.filename = filename - self.done = False def init(self, parallel, index): pass - def fetch(self, ctx: SourceContext): - if self.done: - return + def run(self, ctx: SourceContext): with open(self.filename, "r") as f: line = f.readline() while line != "": ctx.collect(line[:-1]) line = f.readline() - self.done = True class SimpleMapFunction(MapFunction): diff --git a/streaming/python/includes/libstreaming.pxd b/streaming/python/includes/libstreaming.pxd index 899e51694..08a1ce129 100644 --- a/streaming/python/includes/libstreaming.pxd +++ b/streaming/python/includes/libstreaming.pxd @@ -11,9 +11,6 @@ from libcpp.vector cimport vector as c_vector from libcpp.list cimport list as c_list from cpython cimport PyObject cimport cpython -from libcpp.unordered_map cimport unordered_map as c_unordered_map -from cython.operator cimport dereference, postincrement - cdef inline object PyObject_to_object(PyObject* o): # Cast to "object" increments reference count @@ -35,7 +32,7 @@ from ray.includes.unique_ids cimport ( CObjectID, ) -cdef extern from "common/status.h" namespace "ray::streaming" nogil: +cdef extern from "status.h" namespace "ray::streaming" nogil: cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus": pass cdef CStreamingStatus StatusOK "ray::streaming::StreamingStatus::OK" @@ -73,21 +70,9 @@ cdef extern from "message/message.h" namespace "ray::streaming" nogil: cdef CStreamingMessageType MessageTypeMessage "ray::streaming::StreamingMessageType::Message" cdef cppclass CStreamingMessage "ray::streaming::StreamingMessage": inline uint8_t *RawData() const - inline uint8_t *Payload() const - inline uint32_t PayloadSize() const inline uint32_t GetDataSize() const inline CStreamingMessageType GetMessageType() const - inline uint64_t GetMessageId() const - @staticmethod - inline void GetBarrierIdFromRawData(const uint8_t *data, - CStreamingBarrierHeader *barrier_header) - cdef struct CStreamingBarrierHeader "ray::streaming::StreamingBarrierHeader": - CStreamingBarrierType barrier_type; - uint64_t barrier_id; - cdef cppclass CStreamingBarrierType "ray::streaming::StreamingBarrierType": - pass - cdef uint32_t kMessageHeaderSize; - cdef uint32_t kBarrierHeaderSize; + inline uint64_t GetMessageSeqId() const cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil: cdef cppclass CStreamingMessageBundleType "ray::streaming::StreamingMessageBundleType": @@ -112,40 +97,13 @@ cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil: void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums, c_list[shared_ptr[CStreamingMessage]] &msg_list); -cdef extern from "channel/channel.h" namespace "ray::streaming" nogil: +cdef extern from "channel.h" namespace "ray::streaming" nogil: cdef struct CChannelCreationParameter "ray::streaming::ChannelCreationParameter": CChannelCreationParameter() CActorID actor_id; shared_ptr[CRayFunction] async_function; shared_ptr[CRayFunction] sync_function; - cdef struct CStreamingQueueInfo "ray::streaming::StreamingQueueInfo": - uint64_t first_seq_id; - uint64_t last_message_id; - uint64_t target_message_id; - uint64_t consumed_message_id; - - cdef struct CConsumerChannelInfo "ray::streaming::ConsumerChannelInfo": - CObjectID channel_id; - uint64_t current_message_id; - uint64_t barrier_id; - uint64_t partial_barrier_id; - CStreamingQueueInfo queue_info; - uint64_t last_queue_item_delay; - uint64_t last_queue_item_latency; - uint64_t last_queue_target_diff; - uint64_t get_queue_item_times; - uint64_t notify_cnt; - CChannelCreationParameter parameter; - - cdef enum CTransferCreationStatus "ray::streaming::TransferCreationStatus": - FreshStarted = 0 - PullOk = 1 - Timeout = 2 - DataLost = 3 - Invalid = 999 - - cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil: cdef cppclass CReaderClient "ray::streaming::ReaderClient": CReaderClient() @@ -170,12 +128,11 @@ cdef extern from "data_reader.h" namespace "ray::streaming" nogil: CDataReader(shared_ptr[CRuntimeContext] &runtime_context) void Init(const c_vector[CObjectID] &input_ids, const c_vector[CChannelCreationParameter] ¶ms, + const c_vector[uint64_t] &seq_ids, const c_vector[uint64_t] &msg_ids, - c_vector[CTransferCreationStatus] &creation_status, int64_t timer_interval); CStreamingStatus GetBundle(const uint32_t timeout_ms, shared_ptr[CDataBundle] &message) - void GetOffsetInfo(c_unordered_map[CObjectID, CConsumerChannelInfo] *&offset_map); void Stop() @@ -188,9 +145,6 @@ cdef extern from "data_writer.h" namespace "ray::streaming" nogil: const c_vector[uint64_t] &queue_size_vec); long WriteMessageToBufferRing( const CObjectID &q_id, uint8_t *data, uint32_t data_size) - void BroadcastBarrier(uint64_t checkpoint_id, const uint8_t *data, uint32_t data_size) - void GetChannelOffset(c_vector[uint64_t] &result) - void ClearCheckpoint(uint64_t checkpoint_id) void Run() void Stop() diff --git a/streaming/python/includes/transfer.pxi b/streaming/python/includes/transfer.pxi index 4952fd8b5..8061ad547 100644 --- a/streaming/python/includes/transfer.pxi +++ b/streaming/python/includes/transfer.pxi @@ -6,8 +6,6 @@ from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast from libcpp.string cimport string as c_string from libcpp.vector cimport vector as c_vector from libcpp.list cimport list as c_list -from libcpp.unordered_map cimport unordered_map as c_unordered_map -from cython.operator cimport dereference, postincrement from ray.includes.common cimport ( CRayFunction, @@ -40,10 +38,6 @@ from ray.streaming.includes.libstreaming cimport ( CWriterClient, CLocalMemoryBuffer, CChannelCreationParameter, - CTransferCreationStatus, - CConsumerChannelInfo, - CStreamingBarrierHeader, - kBarrierHeaderSize, ) from ray._raylet import JavaFunctionDescriptor @@ -197,7 +191,7 @@ cdef class DataWriter: self.writer = NULL def write(self, ObjectRef qid, const unsigned char[:] value): - """support zero-copy bytes, byte array, array of unsigned char""" + """support zero-copy bytes, bytearray, array of unsigned char""" cdef: CObjectID native_id = qid.data uint64_t msg_id @@ -207,25 +201,6 @@ cdef class DataWriter: msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size) return msg_id - def broadcast_barrier(self, uint64_t checkpoint_id, const unsigned char[:] value): - cdef: - uint8_t *data = (&value[0]) - uint32_t size = value.nbytes - with nogil: - self.writer.BroadcastBarrier(checkpoint_id, data, size) - - def get_output_checkpoints(self): - cdef: - c_vector[uint64_t] results - self.writer.GetChannelOffset(results) - return results - - def clear_checkpoint(self, checkpoint_id): - cdef: - uint64_t c_checkpoint_id = checkpoint_id - with nogil: - self.writer.ClearCheckpoint(c_checkpoint_id) - def stop(self): self.writer.Stop() channel_logger.info("stopped DataWriter") @@ -243,22 +218,25 @@ cdef class DataReader: @staticmethod def create(list py_input_queues, list input_creation_parameters: list[ChannelCreationParameter], + list py_seq_ids, list py_msg_ids, int64_t timer_interval, + c_bool is_recreate, bytes config_bytes, c_bool is_mock): cdef: c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues) c_vector[CChannelCreationParameter] initial_parameters + c_vector[uint64_t] seq_ids c_vector[uint64_t] msg_ids - c_vector[CTransferCreationStatus] c_creation_status CDataReader *c_reader ChannelCreationParameter parameter cdef const unsigned char[:] config_data for param in input_creation_parameters: parameter = param initial_parameters.push_back(parameter.get_parameter()) - + for py_seq_id in py_seq_ids: + seq_ids.push_back(py_seq_id) for py_msg_id in py_msg_ids: msg_ids.push_back(py_msg_id) cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]() @@ -269,19 +247,11 @@ cdef class DataReader: if is_mock: ctx.get().MarkMockTest() c_reader = new CDataReader(ctx) - c_reader.Init(queue_id_vec, initial_parameters, msg_ids, c_creation_status, timer_interval) - - creation_status_map = {} - if not c_creation_status.empty(): - for i in range(queue_id_vec.size()): - k = queue_id_vec[i].Binary() - v = c_creation_status[i] - creation_status_map[k] = v - + c_reader.Init(queue_id_vec, initial_parameters, seq_ids, msg_ids, timer_interval) channel_logger.info("create native reader succeed") cdef DataReader reader = DataReader.__new__(DataReader) reader.reader = c_reader - return reader, creation_status_map + return reader def __dealloc__(self): if self.reader != NULL: @@ -295,33 +265,23 @@ cdef class DataReader: CStreamingStatus status with nogil: status = self.reader.GetBundle(timeout_millis, bundle) + cdef uint32_t bundle_type = (bundle.get().meta.get().GetBundleType()) if status != libstreaming.StatusOK: if status == libstreaming.StatusInterrupted: # avoid cyclic import import ray.streaming.runtime.transfer as transfer raise transfer.ChannelInterruptException("reader interrupted") elif status == libstreaming.StatusInitQueueFailed: - import ray.streaming.runtime.transfer as transfer - raise transfer.ChannelInitException("init channel failed") - elif status == libstreaming.StatusGetBundleTimeOut: - return [] - else: - raise Exception("no such status " + str(status)) + raise Exception("init channel failed") + elif status == libstreaming.StatusWaitQueueTimeOut: + raise Exception("wait channel object timeout") cdef: uint32_t msg_nums - CObjectID queue_id = bundle.get().c_from + CObjectID queue_id c_list[shared_ptr[CStreamingMessage]] msg_list list msgs = [] uint64_t timestamp uint64_t msg_id - c_unordered_map[CObjectID, CConsumerChannelInfo] *offset_map = NULL - shared_ptr[CStreamingMessage] barrier - CStreamingBarrierHeader barrier_header - c_unordered_map[CObjectID, CConsumerChannelInfo].iterator it - - cdef uint32_t bundle_type = (bundle.get().meta.get().GetBundleType()) - # avoid cyclic import - from ray.streaming.runtime.transfer import DataMessage if bundle_type == libstreaming.BundleTypeBundle: msg_nums = bundle.get().meta.get().GetMessageListSize() CStreamingMessageBundle.GetMessageListFromRawData( @@ -331,48 +291,16 @@ cdef class DataReader: msg_list) timestamp = bundle.get().meta.get().GetMessageBundleTs() for msg in msg_list: - msg_bytes = msg.get().Payload()[:msg.get().PayloadSize()] + msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()] qid_bytes = queue_id.Binary() - msg_id = msg.get().GetMessageId() - msgs.append( - DataMessage(msg_bytes, timestamp, msg_id, qid_bytes)) + msg_id = msg.get().GetMessageSeqId() + msgs.append((msg_bytes, msg_id, timestamp, qid_bytes)) return msgs elif bundle_type == libstreaming.BundleTypeEmpty: - timestamp = bundle.get().meta.get().GetMessageBundleTs() - msg_id = bundle.get().meta.get().GetLastMessageId() - return [DataMessage(None, timestamp, msg_id, queue_id.Binary(), True)] - elif bundle.get().meta.get().IsBarrier(): - py_offset_map = {} - self.reader.GetOffsetInfo(offset_map) - it = offset_map.begin() - while it != offset_map.end(): - queue_id_bytes = dereference(it).first.Binary() - current_message_id = dereference(it).second.current_message_id - py_offset_map[queue_id_bytes] = current_message_id - postincrement(it) - msg_nums = bundle.get().meta.get().GetMessageListSize() - CStreamingMessageBundle.GetMessageListFromRawData( - bundle.get().data + libstreaming.kMessageBundleHeaderSize, - bundle.get().data_size - libstreaming.kMessageBundleHeaderSize, - msg_nums, - msg_list) - timestamp = bundle.get().meta.get().GetMessageBundleTs() - barrier = msg_list.front() - msg_id = barrier.get().GetMessageId() - CStreamingMessage.GetBarrierIdFromRawData(barrier.get().Payload(), &barrier_header) - barrier_id = barrier_header.barrier_id - barrier_data = (barrier.get().Payload() + kBarrierHeaderSize)[ - :barrier.get().PayloadSize() - kBarrierHeaderSize] - barrier_type = barrier_header.barrier_type - py_queue_id = queue_id.Binary() - from ray.streaming.runtime.transfer import CheckpointBarrier - return [CheckpointBarrier( - barrier_data, timestamp, msg_id, py_queue_id, py_offset_map, - barrier_id, barrier_type)] + return [] else: raise Exception("Unsupported bundle type {}".format(bundle_type)) - def stop(self): self.reader.Stop() channel_logger.info("stopped DataReader") diff --git a/streaming/python/operator.py b/streaming/python/operator.py index 9163519d6..4952b2d00 100644 --- a/streaming/python/operator.py +++ b/streaming/python/operator.py @@ -3,11 +3,10 @@ import importlib import logging from abc import ABC, abstractmethod +from ray import streaming from ray.streaming import function from ray.streaming import message from ray.streaming.collector import Collector -from ray.streaming.collector import CollectionCollector -from ray.streaming.function import SourceFunction from ray.streaming.runtime import gateway_client logger = logging.getLogger(__name__) @@ -41,14 +40,6 @@ class Operator(ABC): def operator_type(self) -> OperatorType: pass - @abstractmethod - def save_checkpoint(self): - pass - - @abstractmethod - def load_checkpoint(self, checkpoint_obj): - pass - class OneInputOperator(Operator, ABC): """Interface for stream operators with one input.""" @@ -99,20 +90,8 @@ class StreamOperator(Operator, ABC): for collector in self.collectors: collector.collect(record) - def save_checkpoint(self): - self.func.save_checkpoint() - def load_checkpoint(self, checkpoint_obj): - self.func.load_checkpoint(checkpoint_obj) - - -class SourceOperator(Operator, ABC): - @abstractmethod - def fetch(self): - pass - - -class SourceOperatorImpl(SourceOperator, StreamOperator): +class SourceOperator(StreamOperator): """ Operator to run a :class:`function.SourceFunction` """ @@ -125,19 +104,19 @@ class SourceOperatorImpl(SourceOperator, StreamOperator): for collector in self.collectors: collector.collect(message.Record(value)) - def __init__(self, func: SourceFunction): + def __init__(self, func): assert isinstance(func, function.SourceFunction) super().__init__(func) self.source_context = None def open(self, collectors, runtime_context): super().open(collectors, runtime_context) - self.source_context = SourceOperatorImpl.SourceContextImpl(collectors) + self.source_context = SourceOperator.SourceContextImpl(collectors) self.func.init(runtime_context.get_parallelism(), runtime_context.get_task_index()) - def fetch(self): - self.func.fetch(self.source_context) + def run(self): + self.func.run(self.source_context) def operator_type(self): return OperatorType.SOURCE @@ -168,7 +147,8 @@ class FlatMapOperator(StreamOperator, OneInputOperator): def open(self, collectors, runtime_context): super().open(collectors, runtime_context) - self.collection_collector = CollectionCollector(collectors) + self.collection_collector = streaming.collector.CollectionCollector( + collectors) def process_element(self, record): self.func.flat_map(record.value, self.collection_collector) @@ -306,12 +286,12 @@ class ChainedOperator(StreamOperator, ABC): raise Exception("Current operator type is not supported") -class ChainedSourceOperator(SourceOperator, ChainedOperator): +class ChainedSourceOperator(ChainedOperator): def __init__(self, operators, configs): super().__init__(operators, configs) - def fetch(self): - self.operators[0].fetch() + def run(self): + self.operators[0].run() class ChainedOneInputOperator(ChainedOperator): @@ -370,7 +350,7 @@ def load_operator(descriptor_operator_bytes: bytes): _function_to_operator = { - function.SourceFunction: SourceOperatorImpl, + function.SourceFunction: SourceOperator, function.MapFunction: MapOperator, function.FlatMapFunction: FlatMapOperator, function.FilterFunction: FilterOperator, diff --git a/streaming/python/runtime/command.py b/streaming/python/runtime/command.py deleted file mode 100644 index cc5f02e1f..000000000 --- a/streaming/python/runtime/command.py +++ /dev/null @@ -1,30 +0,0 @@ -class BaseWorkerCmd: - """ - base worker cmd - """ - - def __init__(self, actor_id): - self.from_actor_id = actor_id - - -class WorkerCommitReport(BaseWorkerCmd): - """ - worker commit report - """ - - def __init__(self, actor_id, commit_checkpoint_id): - super().__init__(actor_id) - self.commit_checkpoint_id = commit_checkpoint_id - - -class WorkerRollbackRequest(BaseWorkerCmd): - """ - worker rollback request - """ - - def __init__(self, actor_id, exception_msg): - super().__init__(actor_id) - self.__exception_msg = exception_msg - - def exception_msg(self): - return self.__exception_msg diff --git a/streaming/python/runtime/context_backend.py b/streaming/python/runtime/context_backend.py deleted file mode 100644 index 65e811cfe..000000000 --- a/streaming/python/runtime/context_backend.py +++ /dev/null @@ -1,117 +0,0 @@ -import logging -import os -from abc import ABC, abstractmethod -from os import path - -from ray.streaming.config import ConfigHelper, Config - -logger = logging.getLogger(__name__) - - -class ContextBackend(ABC): - @abstractmethod - def get(self, key): - pass - - @abstractmethod - def put(self, key, value): - pass - - @abstractmethod - def remove(self, key): - pass - - -class MemoryContextBackend(ContextBackend): - def __init__(self, conf): - self.__dic = dict() - - def get(self, key): - return self.__dic.get(key) - - def put(self, key, value): - self.__dic[key] = value - - def remove(self, key): - if key in self.__dic: - del self.__dic[key] - - -class LocalFileContextBackend(ContextBackend): - def __init__(self, conf): - self.__dir = ConfigHelper.get_cp_local_file_root_dir(conf) - logger.info("Start init local file state backend, root_dir={}.".format( - self.__dir)) - try: - os.mkdir(self.__dir) - except FileExistsError: - logger.info("dir already exists, skipped.") - - def put(self, key, value): - logger.info("Put value of key {} start.".format(key)) - with open(self.__gen_file_path(key), "wb") as f: - f.write(value) - - def get(self, key): - logger.info("Get value of key {} start.".format(key)) - full_path = self.__gen_file_path(key) - if not os.path.isfile(full_path): - return None - with open(full_path, "rb") as f: - return f.read() - - def remove(self, key): - logger.info("Remove value of key {} start.".format(key)) - try: - os.remove(self.__gen_file_path(key)) - except Exception: - # ignore exception - pass - - def rename(self, src, dst): - logger.info("rename {} to {}".format(src, dst)) - os.rename(self.__gen_file_path(src), self.__gen_file_path(dst)) - - def exists(self, key) -> bool: - return os.path.exists(key) - - def __gen_file_path(self, key): - return path.join(self.__dir, key) - - -class AtomicFsContextBackend(LocalFileContextBackend): - def __init__(self, conf): - super().__init__(conf) - self.__tmp_flag = "_tmp" - - def put(self, key, value): - tmp_key = key + self.__tmp_flag - if super().exists(tmp_key) and not super().exists(key): - super().rename(tmp_key, key) - super().put(tmp_key, value) - super().remove(key) - super().rename(tmp_key, key) - - def get(self, key): - tmp_key = key + self.__tmp_flag - if super().exists(tmp_key) and not super().exists(key): - return super().get(tmp_key) - return super().get(key) - - def remove(self, key): - tmp_key = key + self.__tmp_flag - if super().exists(tmp_key): - super().remove(tmp_key) - super().remove(key) - - -class ContextBackendFactory: - @staticmethod - def get_context_backend(worker_config) -> ContextBackend: - backend_type = ConfigHelper.get_cp_context_backend_type(worker_config) - context_backend = None - if backend_type == Config.CP_STATE_BACKEND_LOCAL_FILE: - context_backend = AtomicFsContextBackend(worker_config) - elif backend_type == Config.CP_STATE_BACKEND_MEMORY: - context_backend = MemoryContextBackend(worker_config) - return context_backend diff --git a/streaming/python/runtime/failover.py b/streaming/python/runtime/failover.py deleted file mode 100644 index 702cdbab3..000000000 --- a/streaming/python/runtime/failover.py +++ /dev/null @@ -1,30 +0,0 @@ -class Barrier: - """ - barrier - """ - - def __init__(self, id): - self.id = id - - def __str__(self): - return "Barrier [id:%s]" % self.id - - -class OpCheckpointInfo: - """ - operator checkpoint info - """ - - def __init__(self, - operator_point=None, - input_points=None, - output_points=None, - checkpoint_id=None): - if input_points is None: - input_points = {} - if output_points is None: - output_points = {} - self.operator_point = operator_point - self.input_points = input_points - self.output_points = output_points - self.checkpoint_id = checkpoint_id diff --git a/streaming/python/runtime/graph.py b/streaming/python/runtime/graph.py index d2719ee12..6db9cf39e 100644 --- a/streaming/python/runtime/graph.py +++ b/streaming/python/runtime/graph.py @@ -5,9 +5,6 @@ import ray import ray.streaming.generated.remote_call_pb2 as remote_call_pb import ray.streaming.operator as operator import ray.streaming.partition as partition -from ray._raylet import ActorID -from ray.actor import ActorHandle -from ray.streaming.config import Config from ray.streaming.generated.streaming_pb2 import Language logger = logging.getLogger(__name__) @@ -30,12 +27,10 @@ class NodeType(enum.Enum): class ExecutionEdge: - def __init__(self, execution_edge_pb, language): - self.source_execution_vertex_id = execution_edge_pb \ - .source_execution_vertex_id - self.target_execution_vertex_id = execution_edge_pb \ - .target_execution_vertex_id - partition_bytes = execution_edge_pb.partition + def __init__(self, edge_pb, language): + self.source_execution_vertex_id = edge_pb.source_execution_vertex_id + self.target_execution_vertex_id = edge_pb.target_execution_vertex_id + partition_bytes = edge_pb.partition # Sink node doesn't have partition function, # so we only deserialize partition_bytes when it's not None or empty if language == Language.PYTHON and partition_bytes: @@ -43,73 +38,50 @@ class ExecutionEdge: class ExecutionVertex: - worker_actor: ActorHandle - - def __init__(self, execution_vertex_pb): - self.execution_vertex_id = execution_vertex_pb.execution_vertex_id - self.execution_job_vertex_id = execution_vertex_pb \ - .execution_job_vertex_id - self.execution_job_vertex_name = execution_vertex_pb \ - .execution_job_vertex_name - self.execution_vertex_index = execution_vertex_pb\ - .execution_vertex_index - self.parallelism = execution_vertex_pb.parallelism - if execution_vertex_pb\ - .language == Language.PYTHON: - # python operator descriptor - operator_bytes = execution_vertex_pb.operator - if execution_vertex_pb.chained: + def __init__(self, vertex_pb): + self.execution_vertex_id = vertex_pb.execution_vertex_id + self.execution_job_vertex_Id = vertex_pb.execution_job_vertex_Id + self.execution_job_vertex_name = vertex_pb.execution_job_vertex_name + self.execution_vertex_index = vertex_pb.execution_vertex_index + self.parallelism = vertex_pb.parallelism + if vertex_pb.language == Language.PYTHON: + operator_bytes = vertex_pb.operator # python operator descriptor + if vertex_pb.chained: logger.info("Load chained operator") self.stream_operator = operator.load_chained_operator( operator_bytes) else: logger.info("Load operator") self.stream_operator = operator.load_operator(operator_bytes) - self.worker_actor = None - if execution_vertex_pb.worker_actor: - self.worker_actor = ray.actor.ActorHandle. \ - _deserialization_helper(execution_vertex_pb.worker_actor) - self.container_id = execution_vertex_pb.container_id - self.build_time = execution_vertex_pb.build_time - self.language = execution_vertex_pb.language - self.config = execution_vertex_pb.config - self.resource = execution_vertex_pb.resource - - @property - def execution_vertex_name(self): - return "{}_{}_{}".format(self.execution_job_vertex_id, - self.execution_job_vertex_name, - self.execution_vertex_id) + self.worker_actor = ray.actor.ActorHandle. \ + _deserialization_helper(vertex_pb.worker_actor) + self.container_id = vertex_pb.container_id + self.build_time = vertex_pb.build_time + self.language = vertex_pb.language + self.config = vertex_pb.config + self.resource = vertex_pb.resource class ExecutionVertexContext: - actor_id: ActorID - execution_vertex: ExecutionVertex - - def __init__( - self, - execution_vertex_context_pb: remote_call_pb.ExecutionVertexContext - ): - self.execution_vertex = ExecutionVertex( - execution_vertex_context_pb.current_execution_vertex) - self.job_name = self.execution_vertex.config[Config.STREAMING_JOB_NAME] - self.exe_vertex_name = self.execution_vertex.execution_vertex_name - self.actor_id = self.execution_vertex.worker_actor._ray_actor_id + def __init__(self, + vertex_context_pb: remote_call_pb.ExecutionVertexContext): + self.execution_vertex = \ + ExecutionVertex(vertex_context_pb.current_execution_vertex) self.upstream_execution_vertices = [ - ExecutionVertex(vertex) for vertex in - execution_vertex_context_pb.upstream_execution_vertices + ExecutionVertex(vertex) + for vertex in vertex_context_pb.upstream_execution_vertices ] self.downstream_execution_vertices = [ - ExecutionVertex(vertex) for vertex in - execution_vertex_context_pb.downstream_execution_vertices + ExecutionVertex(vertex) + for vertex in vertex_context_pb.downstream_execution_vertices ] self.input_execution_edges = [ ExecutionEdge(edge, self.execution_vertex.language) - for edge in execution_vertex_context_pb.input_execution_edges + for edge in vertex_context_pb.input_execution_edges ] self.output_execution_edges = [ ExecutionEdge(edge, self.execution_vertex.language) - for edge in execution_vertex_context_pb.output_execution_edges + for edge in vertex_context_pb.output_execution_edges ] def get_parallelism(self): @@ -140,16 +112,16 @@ class ExecutionVertexContext: def get_task_id(self): return self.execution_vertex.execution_vertex_id - def get_source_actor_by_execution_vertex_id(self, execution_vertex_id): - for execution_vertex in self.upstream_execution_vertices: - if execution_vertex.execution_vertex_id == execution_vertex_id: - return execution_vertex.worker_actor - raise Exception( - "Vertex %s does not exist!".format(execution_vertex_id)) + def get_source_actor_by_vertex_id(self, execution_vertex_id): + for vertex in self.upstream_execution_vertices: + if vertex.execution_vertex_id == execution_vertex_id: + return vertex.worker_actor + raise Exception("ExecutionVertex %s does not exist!" + .format(execution_vertex_id)) - def get_target_actor_by_execution_vertex_id(self, execution_vertex_id): - for execution_vertex in self.downstream_execution_vertices: - if execution_vertex.execution_vertex_id == execution_vertex_id: - return execution_vertex.worker_actor - raise Exception( - "Vertex %s does not exist!".format(execution_vertex_id)) + def get_target_actor_by_vertex_id(self, execution_vertex_id): + for vertex in self.downstream_execution_vertices: + if vertex.execution_vertex_id == execution_vertex_id: + return vertex.worker_actor + raise Exception("ExecutionVertex %s does not exist!" + .format(execution_vertex_id)) diff --git a/streaming/python/runtime/processor.py b/streaming/python/runtime/processor.py index 1083713ee..ccfa55921 100644 --- a/streaming/python/runtime/processor.py +++ b/streaming/python/runtime/processor.py @@ -23,14 +23,6 @@ class Processor(ABC): def close(self): pass - @abstractmethod - def save_checkpoint(self): - pass - - @abstractmethod - def load_checkpoint(self, checkpoint_obj): - pass - class StreamingProcessor(Processor, ABC): """StreamingProcessor is a process unit for a operator.""" @@ -48,13 +40,7 @@ class StreamingProcessor(Processor, ABC): logger.info("Opened Processor {}".format(self)) def close(self): - self.operator.close() - - def save_checkpoint(self): - self.operator.save_checkpoint() - - def load_checkpoint(self, checkpoint_obj): - self.operator.load_checkpoint(checkpoint_obj) + pass class SourceProcessor(StreamingProcessor): @@ -66,8 +52,8 @@ class SourceProcessor(StreamingProcessor): def process(self, record): raise Exception("SourceProcessor should not process record") - def fetch(self): - self.operator.fetch() + def run(self): + self.operator.run() class OneInputProcessor(StreamingProcessor): diff --git a/streaming/python/runtime/remote_call.py b/streaming/python/runtime/remote_call.py deleted file mode 100644 index 4f5f082ee..000000000 --- a/streaming/python/runtime/remote_call.py +++ /dev/null @@ -1,95 +0,0 @@ -import logging -import os -import ray -import time -from enum import Enum - -from ray.actor import ActorHandle -from ray.streaming.generated import remote_call_pb2 -from ray.streaming.runtime.command\ - import WorkerCommitReport, WorkerRollbackRequest - -logger = logging.getLogger(__name__) - - -class CallResult: - """ - Call Result - """ - - def __init__(self, success, result_code, result_msg, result_obj): - self.success = success - self.result_code = result_code - self.result_msg = result_msg - self.result_obj = result_obj - - @staticmethod - def success(payload=None): - return CallResult(True, CallResultEnum.SUCCESS, None, payload) - - @staticmethod - def fail(payload=None): - return CallResult(False, CallResultEnum.FAILED, None, payload) - - @staticmethod - def skipped(msg=None): - return CallResult(True, CallResultEnum.SKIPPED, msg, None) - - def is_success(self): - if self.result_code is CallResultEnum.SUCCESS: - return True - - return False - - -class CallResultEnum(Enum): - """ - call result enum - """ - - SUCCESS = 0 - FAILED = 1 - SKIPPED = 2 - - -class RemoteCallMst: - """ - remote call job master - """ - - @staticmethod - def request_job_worker_rollback(master: ActorHandle, - request: WorkerRollbackRequest): - logger.info("Remote call mst: request job worker rollback start.") - request_pb = remote_call_pb2.BaseWorkerCmd() - request_pb.actor_id = request.from_actor_id - request_pb.timestamp = int(time.time() * 1000.0) - rollback_request_pb = remote_call_pb2.WorkerRollbackRequest() - rollback_request_pb.exception_msg = request.exception_msg() - rollback_request_pb.worker_hostname = os.uname()[1] - rollback_request_pb.worker_pid = str(os.getpid()) - request_pb.detail.Pack(rollback_request_pb) - return_ids = master.requestJobWorkerRollback\ - .remote(request_pb.SerializeToString()) - result = remote_call_pb2.BoolResult() - result.ParseFromString(ray.get(return_ids)) - logger.info("Remote call mst: request job worker rollback finish.") - return result.boolRes - - @staticmethod - def report_job_worker_commit(master: ActorHandle, - report: WorkerCommitReport): - logger.info("Remote call mst: report job worker commit start.") - report_pb = remote_call_pb2.BaseWorkerCmd() - - report_pb.actor_id = report.from_actor_id - report_pb.timestamp = int(time.time() * 1000.0) - wk_commit = remote_call_pb2.WorkerCommitReport() - wk_commit.commit_checkpoint_id = report.commit_checkpoint_id - report_pb.detail.Pack(wk_commit) - return_id = master.reportJobWorkerCommit\ - .remote(report_pb.SerializeToString()) - result = remote_call_pb2.BoolResult() - result.ParseFromString(ray.get(return_id)) - logger.info("Remote call mst: report job worker commit finish.") - return result.boolRes diff --git a/streaming/python/runtime/serialization.py b/streaming/python/runtime/serialization.py index 600e1084c..2d038e482 100644 --- a/streaming/python/runtime/serialization.py +++ b/streaming/python/runtime/serialization.py @@ -3,11 +3,11 @@ import pickle import msgpack from ray.streaming import message -RECORD_TYPE_ID = 0 -KEY_RECORD_TYPE_ID = 1 -CROSS_LANG_TYPE_ID = 0 -JAVA_TYPE_ID = 1 -PYTHON_TYPE_ID = 2 +_RECORD_TYPE_ID = 0 +_KEY_RECORD_TYPE_ID = 1 +_CROSS_LANG_TYPE_ID = b"0" +_JAVA_TYPE_ID = b"1" +_PYTHON_TYPE_ID = b"2" class Serializer(ABC): @@ -33,21 +33,21 @@ class CrossLangSerializer(Serializer): def serialize(self, obj): if type(obj) is message.Record: - fields = [RECORD_TYPE_ID, obj.stream, obj.value] + fields = [_RECORD_TYPE_ID, obj.stream, obj.value] elif type(obj) is message.KeyRecord: - fields = [KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value] + fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value] else: raise Exception("Unsupported value {}".format(obj)) return msgpack.packb(fields, use_bin_type=True) def deserialize(self, data): - fields = msgpack.unpackb(data, raw=False) - if fields[0] == RECORD_TYPE_ID: + fields = msgpack.unpackb(data, raw=False, strict_map_key=False) + if fields[0] == _RECORD_TYPE_ID: stream, value = fields[1:] record = message.Record(value) record.stream = stream return record - elif fields[0] == KEY_RECORD_TYPE_ID: + elif fields[0] == _KEY_RECORD_TYPE_ID: stream, key, value = fields[1:] key_record = message.KeyRecord(key, value) key_record.stream = stream diff --git a/streaming/python/runtime/task.py b/streaming/python/runtime/task.py index 54ec3cf38..713a29703 100644 --- a/streaming/python/runtime/task.py +++ b/streaming/python/runtime/task.py @@ -1,30 +1,14 @@ import logging -import pickle import threading -import time -import typing from abc import ABC, abstractmethod -from typing import Optional from ray.streaming.collector import OutputCollector from ray.streaming.config import Config from ray.streaming.context import RuntimeContextImpl -from ray.streaming.generated import remote_call_pb2 from ray.streaming.runtime import serialization -from ray.streaming.runtime.command import WorkerCommitReport -from ray.streaming.runtime.failover import Barrier, OpCheckpointInfo -from ray.streaming.runtime.remote_call import RemoteCallMst from ray.streaming.runtime.serialization import \ PythonSerializer, CrossLangSerializer -from ray.streaming.runtime.transfer import CheckpointBarrier -from ray.streaming.runtime.transfer import DataMessage from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader -from ray.streaming.runtime.transfer import ChannelRecoverInfo -from ray.streaming.runtime.transfer import ChannelInterruptException - -if typing.TYPE_CHECKING: - from ray.streaming.runtime.worker import JobWorker - from ray.streaming.runtime.processor import Processor, SourceProcessor logger = logging.getLogger(__name__) @@ -32,85 +16,18 @@ logger = logging.getLogger(__name__) class StreamTask(ABC): """Base class for all streaming tasks. Each task runs a processor.""" - def __init__(self, task_id: int, processor: "Processor", - worker: "JobWorker", last_checkpoint_id: int): - self.worker_context = worker.worker_context - self.vertex_context = worker.execution_vertex_context + def __init__(self, task_id, processor, worker): self.task_id = task_id self.processor = processor self.worker = worker - self.config: dict = worker.config - self.reader: Optional[DataReader] = None - self.writer: Optional[DataWriter] = None - self.is_initial_state = True - self.last_checkpoint_id: int = last_checkpoint_id + self.config = worker.config + self.reader = None # DataReader + self.writers = {} # ExecutionEdge -> DataWriter + self.thread = None + self.prepare_task() self.thread = threading.Thread(target=self.run, daemon=True) - def do_checkpoint(self, checkpoint_id: int, input_points): - logger.info("Start do checkpoint, cp id {}, inputPoints {}.".format( - checkpoint_id, input_points)) - - output_points = None - if self.writer is not None: - output_points = self.writer.get_output_checkpoints() - - operator_checkpoint = self.processor.save_checkpoint() - op_checkpoint_info = OpCheckpointInfo( - operator_checkpoint, input_points, output_points, checkpoint_id) - self.__save_cp_state_and_report(op_checkpoint_info, checkpoint_id) - - barrier_pb = remote_call_pb2.Barrier() - barrier_pb.id = checkpoint_id - byte_buffer = barrier_pb.SerializeToString() - if self.writer is not None: - self.writer.broadcast_barrier(checkpoint_id, byte_buffer) - logger.info("Operator checkpoint {} finish.".format(checkpoint_id)) - - def __save_cp_state_and_report(self, op_checkpoint_info, checkpoint_id): - logger.info( - "Start to save cp state and report, checkpoint id is {}.".format( - checkpoint_id)) - self.__save_cp(op_checkpoint_info, checkpoint_id) - self.__report_commit(checkpoint_id) - self.last_checkpoint_id = checkpoint_id - - def __save_cp(self, op_checkpoint_info, checkpoint_id): - logger.info("save operator cp, op_checkpoint_info={}".format( - op_checkpoint_info)) - cp_bytes = pickle.dumps(op_checkpoint_info) - self.worker.context_backend.put( - self.__gen_op_checkpoint_key(checkpoint_id), cp_bytes) - - def __report_commit(self, checkpoint_id: int): - logger.info("Report commit, checkpoint id {}.".format(checkpoint_id)) - report = WorkerCommitReport(self.vertex_context.actor_id.binary(), - checkpoint_id) - RemoteCallMst.report_job_worker_commit(self.worker.master_actor, - report) - - def clear_expired_cp_state(self, checkpoint_id): - cp_key = self.__gen_op_checkpoint_key(checkpoint_id) - self.worker.context_backend.remove(cp_key) - - def clear_expired_queue_msg(self, checkpoint_id): - # clear operator checkpoint - if self.writer is not None: - self.writer.clear_checkpoint(checkpoint_id) - - def request_rollback(self, exception_msg: str): - self.worker.request_rollback(exception_msg) - - def __gen_op_checkpoint_key(self, checkpoint_id): - op_checkpoint_key = Config.JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY + str( - self.vertex_context.job_name) + "_" + str( - self.vertex_context.exe_vertex_name) + "_" + str(checkpoint_id) - logger.info( - "Generate op checkpoint key {}. ".format(op_checkpoint_key)) - return op_checkpoint_key - - def prepare_task(self, is_recreate: bool): - logger.info( - "Preparing stream task, is_recreate={}.".format(is_recreate)) + def prepare_task(self): channel_conf = dict(self.worker.config) channel_size = int( self.worker.config.get(Config.CHANNEL_SIZE, @@ -122,76 +39,45 @@ class StreamTask(ABC): execution_vertex_context = self.worker.execution_vertex_context build_time = execution_vertex_context.build_time - # when use memory state, if actor throw exception, will miss state - op_checkpoint_info = OpCheckpointInfo() - - cp_bytes = None - # get operator checkpoint - if is_recreate: - cp_key = self.__gen_op_checkpoint_key(self.last_checkpoint_id) - logger.info("Getting task checkpoints from state, " - "cpKey={}, checkpointId={}.".format( - cp_key, self.last_checkpoint_id)) - cp_bytes = self.worker.context_backend.get(cp_key) - if cp_bytes is None: - msg = "Task recover failed, checkpoint is null!"\ - "cpKey={}".format(cp_key) - raise RuntimeError(msg) - - if cp_bytes is not None: - op_checkpoint_info = pickle.loads(cp_bytes) - self.processor.load_checkpoint(op_checkpoint_info.operator_point) - logger.info("Stream task recover from checkpoint state," - "checkpoint bytes len={}, checkpointInfo={}.".format( - cp_bytes.__len__(), op_checkpoint_info)) - # writers collectors = [] output_actors_map = {} for edge in execution_vertex_context.output_execution_edges: target_task_id = edge.target_execution_vertex_id - target_actor = execution_vertex_context \ - .get_target_actor_by_execution_vertex_id(target_task_id) + target_actor = execution_vertex_context\ + .get_target_actor_by_vertex_id(target_task_id) channel_name = ChannelID.gen_id(self.task_id, target_task_id, build_time) output_actors_map[channel_name] = target_actor - if len(output_actors_map) > 0: - channel_str_ids = list(output_actors_map.keys()) - target_actors = list(output_actors_map.values()) - logger.info("Create DataWriter channel_ids {}," - "target_actors {}, output_points={}.".format( - channel_str_ids, target_actors, - op_checkpoint_info.output_points)) - self.writer = DataWriter(channel_str_ids, target_actors, - channel_conf) - logger.info("Create DataWriter succeed channel_ids {}, " - "target_actors {}.".format(channel_str_ids, - target_actors)) - for edge in execution_vertex_context.output_execution_edges: + if len(output_actors_map) > 0: + channel_ids = list(output_actors_map.keys()) + target_actors = list(output_actors_map.values()) + logger.info( + "Create DataWriter channel_ids {}, target_actors {}." + .format(channel_ids, target_actors)) + writer = DataWriter(channel_ids, target_actors, channel_conf) + self.writers[edge] = writer collectors.append( - OutputCollector(self.writer, channel_str_ids, - target_actors, edge.partition)) + OutputCollector(writer, channel_ids, target_actors, + edge.partition)) # readers input_actor_map = {} for edge in execution_vertex_context.input_execution_edges: source_task_id = edge.source_execution_vertex_id - source_actor = execution_vertex_context \ - .get_source_actor_by_execution_vertex_id(source_task_id) + source_actor = execution_vertex_context\ + .get_source_actor_by_vertex_id(source_task_id) channel_name = ChannelID.gen_id(source_task_id, self.task_id, build_time) input_actor_map[channel_name] = source_actor if len(input_actor_map) > 0: - channel_str_ids = list(input_actor_map.keys()) + channel_ids = list(input_actor_map.keys()) from_actors = list(input_actor_map.values()) - logger.info("Create DataReader, channels {}," - "input_actors {}, input_points={}.".format( - channel_str_ids, from_actors, - op_checkpoint_info.input_points)) - self.reader = DataReader(channel_str_ids, from_actors, - channel_conf) + logger.info("Create DataReader, channels {}, input_actors {}." + .format(channel_ids, from_actors)) + self.reader = DataReader(channel_ids, from_actors, channel_conf) def exit_handler(): # Make DataReader stop read data when MockQueue destructor @@ -201,31 +87,21 @@ class StreamTask(ABC): import atexit atexit.register(exit_handler) + # TODO(chaokunyang) add task/job config runtime_context = RuntimeContextImpl( self.worker.task_id, execution_vertex_context.execution_vertex.execution_vertex_index, - execution_vertex_context.get_parallelism(), - config=channel_conf, - job_config=channel_conf) + execution_vertex_context.get_parallelism()) logger.info("open Processor {}".format(self.processor)) self.processor.open(collectors, runtime_context) - # immediately save cp. In case of FO in cp 0 - # or use old cp in multi node FO. - self.__save_cp(op_checkpoint_info, self.last_checkpoint_id) - - def recover(self, is_recreate: bool): - self.prepare_task(is_recreate) - - recover_info = ChannelRecoverInfo() - if self.reader is not None: - recover_info = self.reader.get_channel_recover_info() + @abstractmethod + def init(self): + pass + def start(self): self.thread.start() - logger.info("Start operator success.") - return recover_info - @abstractmethod def run(self): pass @@ -234,24 +110,14 @@ class StreamTask(ABC): def cancel_task(self): pass - @abstractmethod - def commit_trigger(self, barrier: Barrier) -> bool: - pass - class InputStreamTask(StreamTask): """Base class for stream tasks that execute a :class:`runtime.processor.OneInputProcessor` or :class:`runtime.processor.TwoInputProcessor` """ - def commit_trigger(self, barrier): - raise RuntimeError( - "commit_trigger is only supported in SourceStreamTask.") - - def __init__(self, task_id, processor_instance, worker, - last_checkpoint_id): - super().__init__(task_id, processor_instance, worker, - last_checkpoint_id) + def __init__(self, task_id, processor_instance, worker): + super().__init__(task_id, processor_instance, worker) self.running = True self.stopped = False self.read_timeout_millis = \ @@ -260,58 +126,25 @@ class InputStreamTask(StreamTask): self.python_serializer = PythonSerializer() self.cross_lang_serializer = CrossLangSerializer() + def init(self): + pass + def run(self): - logger.info("Input task thread start.") - try: - while self.running: - self.worker.initial_state_lock.acquire() - try: - item = self.reader.read(self.read_timeout_millis) - self.is_initial_state = False - finally: - self.worker.initial_state_lock.release() - - if item is None: - continue - - if isinstance(item, DataMessage): - msg_data = item.body - type_id = msg_data[0] - if type_id == serialization.PYTHON_TYPE_ID: - msg = self.python_serializer.deserialize(msg_data[1:]) - else: - msg = self.cross_lang_serializer.deserialize( - msg_data[1:]) - self.processor.process(msg) - elif isinstance(item, CheckpointBarrier): - logger.info("Got barrier:{}".format(item)) - logger.info("Start to do checkpoint {}.".format( - item.checkpoint_id)) - - input_points = item.get_input_checkpoints() - - self.do_checkpoint(item.checkpoint_id, input_points) - logger.info("Do checkpoint {} success.".format( - item.checkpoint_id)) + while self.running: + item = self.reader.read(self.read_timeout_millis) + if item is not None: + msg_data = item.body() + type_id = msg_data[:1] + if (type_id == serialization._PYTHON_TYPE_ID): + msg = self.python_serializer.deserialize(msg_data[1:]) else: - raise RuntimeError( - "Unknown item type! item={}".format(item)) - - except ChannelInterruptException: - logger.info("queue has stopped.") - except BaseException as e: - logger.exception( - "Last success checkpointId={}, now occur error.".format( - self.last_checkpoint_id)) - self.request_rollback(str(e)) - - logger.info("Source fetcher thread exit.") + msg = self.cross_lang_serializer.deserialize(msg_data[1:]) + self.processor.process(msg) self.stopped = True def cancel_task(self): self.running = False while not self.stopped: - time.sleep(0.5) pass @@ -319,64 +152,22 @@ class OneInputStreamTask(InputStreamTask): """A stream task for executing :class:`runtime.processor.OneInputProcessor` """ - def __init__(self, task_id, processor_instance, worker, - last_checkpoint_id): - super().__init__(task_id, processor_instance, worker, - last_checkpoint_id) + def __init__(self, task_id, processor_instance, worker): + super().__init__(task_id, processor_instance, worker) class SourceStreamTask(StreamTask): """A stream task for executing :class:`runtime.processor.SourceProcessor` """ - processor: "SourceProcessor" - def __init__(self, task_id: int, processor_instance: "SourceProcessor", - worker: "JobWorker", last_checkpoint_id): - super().__init__(task_id, processor_instance, worker, - last_checkpoint_id) - self.running = True - self.stopped = False - self.__pending_barrier: Optional[Barrier] = None + def __init__(self, task_id, processor_instance, worker): + super().__init__(task_id, processor_instance, worker) + + def init(self): + pass def run(self): - logger.info("Source task thread start.") - try: - while self.running: - self.processor.fetch() - # check checkpoint - if self.__pending_barrier is not None: - # source fetcher only have outputPoints - barrier = self.__pending_barrier - logger.info("Start to do checkpoint {}.".format( - barrier.id)) - self.do_checkpoint(barrier.id, barrier) - logger.info("Finish to do checkpoint {}.".format( - barrier.id)) - self.__pending_barrier = None - - except ChannelInterruptException: - logger.info("queue has stopped.") - except Exception as e: - logger.exception( - "Last success checkpointId={}, now occur error.".format( - self.last_checkpoint_id)) - self.request_rollback(str(e)) - - logger.info("Source fetcher thread exit.") - self.stopped = True - - def commit_trigger(self, barrier): - if self.__pending_barrier is not None: - logger.warning( - "Last barrier is not broadcast now, skip this barrier trigger." - ) - return False - - self.__pending_barrier = barrier - return True + self.processor.run() def cancel_task(self): - self.running = False - while not self.stopped: - time.sleep(0.5) - pass + pass diff --git a/streaming/python/runtime/transfer.py b/streaming/python/runtime/transfer.py index 8091c1d21..5a19bec9a 100644 --- a/streaming/python/runtime/transfer.py +++ b/streaming/python/runtime/transfer.py @@ -2,8 +2,6 @@ import logging import random from queue import Queue from typing import List -from enum import Enum -from abc import ABC, abstractmethod import ray import ray.streaming._streaming as _streaming @@ -15,7 +13,6 @@ from ray._raylet import PythonFunctionDescriptor from ray._raylet import Language CHANNEL_ID_LEN = 20 -logger = logging.getLogger(__name__) class ChannelID: @@ -100,70 +97,40 @@ def channel_bytes_to_str(id_bytes): return bytes.hex(id_bytes) -class Message(ABC): - @property - @abstractmethod - def body(self): - """Message data""" - pass - - @property - @abstractmethod - def timestamp(self): - """Get timestamp when item is written by upstream DataWriter - """ - pass - - @property - @abstractmethod - def channel_id(self): - """Get string id of channel where data is coming from""" - pass - - @property - @abstractmethod - def message_id(self): - """Get message id of the message""" - pass - - -class DataMessage(Message): +class DataMessage: """ - DataMessage represents data between upstream and downstream operator. + DataMessage represents data between upstream and downstream operator """ def __init__(self, body, timestamp, - message_id, channel_id, + message_id_, is_empty_message=False): self.__body = body self.__timestamp = timestamp self.__channel_id = channel_id - self.__message_id = message_id + self.__message_id = message_id_ self.__is_empty_message = is_empty_message def __len__(self): return len(self.__body) - @property def body(self): + """Message data""" return self.__body - @property def timestamp(self): + """Get timestamp when item is written by upstream DataWriter + """ return self.__timestamp - @property def channel_id(self): + """Get string id of channel where data is coming from + """ return self.__channel_id - @property - def message_id(self): - return self.__message_id - - @property def is_empty_message(self): """Whether this message is an empty message. Upstream DataWriter will send an empty message when this is no data @@ -171,47 +138,10 @@ class DataMessage(Message): """ return self.__is_empty_message - -class CheckpointBarrier(Message): - """ - CheckpointBarrier separates the records in the data stream into the set of - records that goes into the current snapshot, and the records that go into - the next snapshot. Each barrier carries the ID of the snapshot whose - records it pushed in front of it. - """ - - def __init__(self, barrier_data, timestamp, message_id, channel_id, - offsets, barrier_id, barrier_type): - self.__barrier_data = barrier_data - self.__timestamp = timestamp - self.__message_id = message_id - self.__channel_id = channel_id - self.checkpoint_id = barrier_id - self.offsets = offsets - self.barrier_type = barrier_type - - @property - def body(self): - return self.__barrier_data - - @property - def timestamp(self): - return self.__timestamp - - @property - def channel_id(self): - return self.__channel_id - @property def message_id(self): return self.__message_id - def get_input_checkpoints(self): - return self.offsets - - def __str__(self): - return "Barrier(Checkpoint id : {})".format(self.checkpoint_id) - class ChannelCreationParametersBuilder: """ @@ -288,6 +218,9 @@ class ChannelCreationParametersBuilder: _python_reader_sync_function_descriptor = sync_function +logger = logging.getLogger(__name__) + + class DataWriter: """Data Writer is a wrapper of streaming c++ DataWriter, which sends data to downstream workers @@ -331,26 +264,6 @@ class DataWriter: msg_id = self.writer.write(channel_id.object_qid, item) return msg_id - def broadcast_barrier(self, checkpoint_id: int, body: bytes): - """Broadcast barriers to all downstream channels - Args: - checkpoint_id: the checkpoint_id - body: barrier payload - """ - self.writer.broadcast_barrier(checkpoint_id, body) - - def get_output_checkpoints(self) -> List[int]: - """Get output offsets of all downstream channels - Returns: - a list contains current msg_id of each downstream channel - """ - return self.writer.get_output_checkpoints() - - def clear_checkpoint(self, checkpoint_id): - logger.info("producer start to clear checkpoint, checkpoint_id={}" - .format(checkpoint_id)) - self.writer.clear_checkpoint(checkpoint_id) - def stop(self): logger.info("stopping channel writer.") self.writer.stop() @@ -381,20 +294,18 @@ class DataReader: ] creation_parameters = ChannelCreationParametersBuilder() creation_parameters.build_input_queue_parameters(from_actors) + py_seq_ids = [0 for _ in range(len(input_channels))] py_msg_ids = [0 for _ in range(len(input_channels))] timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1)) + is_recreate = bool(conf.get(Config.IS_RECREATE, False)) config_bytes = _to_native_conf(conf) self.__queue = Queue(10000) is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL - self.reader, queues_creation_status = _streaming.DataReader.create( + self.reader = _streaming.DataReader.create( py_input_channels, creation_parameters.get_parameters(), - py_msg_ids, timer_interval, config_bytes, is_mock) - - self.__creation_status = {} - for q, status in queues_creation_status.items(): - self.__creation_status[q] = ChannelCreationStatus(status) - logger.info("create DataReader succeed, creation_status={}".format( - self.__creation_status)) + py_seq_ids, py_msg_ids, timer_interval, is_recreate, config_bytes, + is_mock) + logger.info("create DataReader succeed") def read(self, timeout_millis): """Read data from channel @@ -405,17 +316,16 @@ class DataReader: channel item """ if self.__queue.empty(): - messages = self.reader.read(timeout_millis) - for message in messages: - self.__queue.put(message) - + msgs = self.reader.read(timeout_millis) + for msg in msgs: + msg_bytes, msg_id, timestamp, qid_bytes = msg + data_msg = DataMessage(msg_bytes, timestamp, + channel_bytes_to_str(qid_bytes), msg_id) + self.__queue.put(data_msg) if self.__queue.empty(): return None return self.__queue.get() - def get_channel_recover_info(self): - return ChannelRecoverInfo(self.__creation_status) - def stop(self): logger.info("stopping Data Reader.") self.reader.stop() @@ -462,45 +372,3 @@ class ChannelInitException(Exception): class ChannelInterruptException(Exception): def __init__(self, msg=None): self.msg = msg - - -class ChannelRecoverInfo: - def __init__(self, queue_creation_status_map=None): - if queue_creation_status_map is None: - queue_creation_status_map = {} - self.__queue_creation_status_map = queue_creation_status_map - - def get_creation_status(self): - return self.__queue_creation_status_map - - def get_data_lost_queues(self): - data_lost_queues = set() - for (q, status) in self.__queue_creation_status_map.items(): - if status == ChannelCreationStatus.DataLost: - data_lost_queues.add(q) - return data_lost_queues - - def __str__(self): - return "QueueRecoverInfo [dataLostQueues=%s]" \ - % (self.get_data_lost_queues()) - - -class ChannelCreationStatus(Enum): - FreshStarted = 0 - PullOk = 1 - Timeout = 2 - DataLost = 3 - - -def channel_id_bytes_to_str(id_bytes): - """ - Args: - id_bytes: bytes representation of channel id - - Returns: - string representation of channel id - """ - assert type(id_bytes) in [str, bytes] - if isinstance(id_bytes, str): - return id_bytes - return bytes.hex(id_bytes) diff --git a/streaming/python/runtime/worker.py b/streaming/python/runtime/worker.py index d6d8eb02b..0fdb56096 100644 --- a/streaming/python/runtime/worker.py +++ b/streaming/python/runtime/worker.py @@ -1,23 +1,12 @@ -import enum -import logging.config -import os -import threading -import time -from typing import Optional +import logging import ray -import ray.streaming.runtime.processor as processor -from ray.actor import ActorHandle -from ray.streaming.generated import remote_call_pb2 -from ray.streaming.runtime.command import WorkerRollbackRequest -from ray.streaming.runtime.failover import Barrier -from ray.streaming.runtime.graph import ExecutionVertexContext, ExecutionVertex -from ray.streaming.runtime.remote_call import CallResult, RemoteCallMst -from ray.streaming.runtime.context_backend import ContextBackendFactory -from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask -from ray.streaming.runtime.transfer import channel_bytes_to_str -from ray.streaming.config import Config import ray.streaming._streaming as _streaming +import ray.streaming.generated.remote_call_pb2 as remote_call_pb +import ray.streaming.runtime.processor as processor +from ray.streaming.config import Config +from ray.streaming.runtime.graph import ExecutionVertexContext +from ray.streaming.runtime.task import SourceStreamTask, OneInputStreamTask logger = logging.getLogger(__name__) @@ -29,179 +18,74 @@ _NOT_READY_FLAG_ = b" " * 4 class JobWorker(object): """A streaming job worker is used to execute user-defined function and interact with `JobMaster`""" - master_actor: Optional[ActorHandle] - worker_context: Optional[remote_call_pb2.PythonJobWorkerContext] - execution_vertex_context: Optional[ExecutionVertexContext] - __need_rollback: bool - def __init__(self, execution_vertex_pb_bytes): - logger.info("Creating job worker, pid={}".format(os.getpid())) - execution_vertex_pb = remote_call_pb2\ - .ExecutionVertexContext.ExecutionVertex() - execution_vertex_pb.ParseFromString(execution_vertex_pb_bytes) - self.execution_vertex = ExecutionVertex(execution_vertex_pb) - self.config = self.execution_vertex.config + def __init__(self): self.worker_context = None self.execution_vertex_context = None + self.config = None self.task_id = None self.task = None self.stream_processor = None - self.master_actor = None - self.context_backend = ContextBackendFactory.get_context_backend( - self.config) - self.initial_state_lock = threading.Lock() - self.__rollback_cnt: int = 0 - self.__is_recreate: bool = False - self.__state = WorkerState() - self.__need_rollback = True self.reader_client = None self.writer_client = None - try: - # load checkpoint - was_reconstructed = ray.get_runtime_context( - ).was_current_actor_reconstructed - - logger.info( - "Worker was reconstructed: {}".format(was_reconstructed)) - if was_reconstructed: - job_worker_context_key = self.__get_job_worker_context_key() - logger.info("Worker get checkpoint state by key: {}.".format( - job_worker_context_key)) - context_bytes = self.context_backend.get( - job_worker_context_key) - if context_bytes is not None and context_bytes.__len__() > 0: - self.init(context_bytes) - self.request_rollback( - "Python worker recover from checkpoint.") - else: - logger.error( - "Error! Worker get checkpoint state by key {}" - " returns None, please check your state backend" - ", only reliable state backend supports fail-over." - .format(job_worker_context_key)) - except Exception: - logger.exception("Error in __init__ of JobWorker") - logger.info("Creating job worker succeeded. worker config {}".format( - self.config)) + logger.info("Creating job worker succeeded.") def init(self, worker_context_bytes): - logger.info("Start to init job worker") - try: - # deserialize context - worker_context = remote_call_pb2.PythonJobWorkerContext() - worker_context.ParseFromString(worker_context_bytes) - self.worker_context = worker_context - self.master_actor = ActorHandle._deserialization_helper( - worker_context.master_actor) + worker_context = remote_call_pb.PythonJobWorkerContext() + worker_context.ParseFromString(worker_context_bytes) + self.worker_context = worker_context - # build vertex context from pb - self.execution_vertex_context = ExecutionVertexContext( - worker_context.execution_vertex_context) - self.execution_vertex = self\ - .execution_vertex_context.execution_vertex + # build vertex context from pb + self.execution_vertex_context = ExecutionVertexContext( + worker_context.execution_vertex_context) - # save context - job_worker_context_key = self.__get_job_worker_context_key() - self.context_backend.put(job_worker_context_key, - worker_context_bytes) + # use vertex id as task id + self.task_id = self.execution_vertex_context.get_task_id() - # use vertex id as task id - self.task_id = self.execution_vertex_context.get_task_id() - # build and get processor from operator - operator = self.execution_vertex_context.stream_operator - self.stream_processor = processor.build_processor(operator) - logger.info("Initializing job worker, exe_vertex_name={}," - "task_id: {}, operator: {}, pid={}".format( - self.execution_vertex_context.exe_vertex_name, - self.task_id, self.stream_processor, os.getpid())) + # build and get processor from operator + operator = self.execution_vertex_context.stream_operator + self.stream_processor = processor.build_processor(operator) + logger.info( + "Initializing job worker, task_id: {}, operator: {}.".format( + self.task_id, self.stream_processor)) - # get config from vertex - self.config = self.execution_vertex_context.config + # get config from vertex + self.config = self.execution_vertex_context.config - if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL): - self.reader_client = _streaming.ReaderClient() - self.writer_client = _streaming.WriterClient() + if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL): + self.reader_client = _streaming.ReaderClient() + self.writer_client = _streaming.WriterClient() - logger.info("Job worker init succeeded.") - except Exception: - logger.exception("Error when init job worker.") - return False + self.task = self.create_stream_task() + + logger.info("Job worker init succeeded.") return True - def create_stream_task(self, checkpoint_id): + def start(self): + self.task.start() + logger.info("Job worker start succeeded.") + + def create_stream_task(self): if isinstance(self.stream_processor, processor.SourceProcessor): - return SourceStreamTask(self.task_id, self.stream_processor, self, - checkpoint_id) + return SourceStreamTask(self.task_id, self.stream_processor, self) elif isinstance(self.stream_processor, processor.OneInputProcessor): return OneInputStreamTask(self.task_id, self.stream_processor, - self, checkpoint_id) + self) else: raise Exception("Unsupported processor type: " + - str(type(self.stream_processor))) + type(self.stream_processor)) - def rollback(self, checkpoint_id_bytes): - checkpoint_id_pb = remote_call_pb2.CheckpointId() - checkpoint_id_pb.ParseFromString(checkpoint_id_bytes) - checkpoint_id = checkpoint_id_pb.checkpoint_id - - logger.info("Start rollback, checkpoint_id={}".format(checkpoint_id)) - - self.__rollback_cnt += 1 - if self.__rollback_cnt > 1: - self.__is_recreate = True - # skip useless rollback - self.initial_state_lock.acquire() - try: - if self.task is not None and self.task.thread.is_alive()\ - and checkpoint_id == self.task.last_checkpoint_id\ - and self.task.is_initial_state: - logger.info( - "Task is already in initial state, skip this rollback.") - return self.__gen_call_result( - CallResult.skipped( - "Task is already in initial state, skip this rollback." - )) - finally: - self.initial_state_lock.release() - - # restart task - try: - if self.task is not None: - # make sure the runner is closed - self.task.cancel_task() - del self.task - - self.task = self.create_stream_task(checkpoint_id) - - q_recover_info = self.task.recover(self.__is_recreate) - - self.__state.set_type(StateType.RUNNING) - self.__need_rollback = False - - logger.info( - "Rollback success, checkpoint is {}, qRecoverInfo is {}.". - format(checkpoint_id, q_recover_info)) - - return self.__gen_call_result(CallResult.success(q_recover_info)) - except Exception: - logger.exception("Rollback has exception.") - return self.__gen_call_result(CallResult.fail()) - - def on_reader_message(self, *buffers): + def on_reader_message(self, buffer: bytes): """Called by upstream queue writer to send data message to downstream queue reader. """ - if self.reader_client is None: - logger.info("reader_client is None, skip writer transfer") - return - self.reader_client.on_reader_message(*buffers) + self.reader_client.on_reader_message(buffer) def on_reader_message_sync(self, buffer: bytes): - """Called by upstream queue writer to send - control message to downstream downstream queue reader. + """Called by upstream queue writer to send control message to downstream + downstream queue reader. """ if self.reader_client is None: - logger.info("task is None, skip reader transfer") return _NOT_READY_FLAG_ result = self.reader_client.on_reader_message_sync(buffer) return result.to_pybytes() @@ -210,9 +94,6 @@ class JobWorker(object): """Called by downstream queue reader to send notify message to upstream queue writer. """ - if self.writer_client is None: - logger.info("writer_client is None, skip writer transfer") - return self.writer_client.on_writer_message(buffer) def on_writer_message_sync(self, buffer: bytes): @@ -223,164 +104,3 @@ class JobWorker(object): return _NOT_READY_FLAG_ result = self.writer_client.on_writer_message_sync(buffer) return result.to_pybytes() - - def shutdown_without_reconstruction(self): - logger.info("Python worker shutdown without reconstruction.") - ray.actor.exit_actor() - - def notify_checkpoint_timeout(self, checkpoint_id_bytes): - pass - - def commit(self, barrier_bytes): - barrier_pb = remote_call_pb2.Barrier() - barrier_pb.ParseFromString(barrier_bytes) - barrier = Barrier(barrier_pb.id) - logger.info("Receive trigger, barrier is {}.".format(barrier)) - - if self.task is not None: - self.task.commit_trigger(barrier) - ret = remote_call_pb2.BoolResult() - ret.boolRes = True - return ret.SerializeToString() - - def clear_expired_cp(self, state_checkpoint_id_bytes, - queue_checkpoint_id_bytes): - state_checkpoint_id = self.__parse_to_checkpoint_id( - state_checkpoint_id_bytes) - queue_checkpoint_id = self.__parse_to_checkpoint_id( - queue_checkpoint_id_bytes) - logger.info("Start to clear expired checkpoint, checkpoint_id={}," - "queue_checkpoint_id={}, exe_vertex_name={}.".format( - state_checkpoint_id, queue_checkpoint_id, - self.execution_vertex_context.exe_vertex_name)) - - ret = remote_call_pb2.BoolResult() - ret.boolRes = self.__clear_expired_cp_state(state_checkpoint_id) \ - if state_checkpoint_id > 0 else True - ret.boolRes &= self.__clear_expired_queue_msg(queue_checkpoint_id) - logger.info( - "Clear expired checkpoint done, result={}, checkpoint_id={}," - "queue_checkpoint_id={}, exe_vertex_name={}.".format( - ret.boolRes, state_checkpoint_id, queue_checkpoint_id, - self.execution_vertex_context.exe_vertex_name)) - return ret.SerializeToString() - - def __clear_expired_cp_state(self, checkpoint_id): - if self.__need_rollback: - logger.warning("Need rollback, skip clear_expired_cp_state" - ", checkpoint id: {}".format(checkpoint_id)) - return False - - logger.info("Clear expired checkpoint state, cp id is {}.".format( - checkpoint_id)) - - if self.task is not None: - self.task.clear_expired_cp_state(checkpoint_id) - return True - - def __clear_expired_queue_msg(self, checkpoint_id): - if self.__need_rollback: - logger.warning("Need rollback, skip clear_expired_queue_msg" - ", checkpoint id: {}".format(checkpoint_id)) - return False - - logger.info("Clear expired queue msg, checkpoint_id is {}.".format( - checkpoint_id)) - - if self.task is not None: - self.task.clear_expired_queue_msg(checkpoint_id) - return True - - def __parse_to_checkpoint_id(self, checkpoint_id_bytes): - checkpoint_id_pb = remote_call_pb2.CheckpointId() - checkpoint_id_pb.ParseFromString(checkpoint_id_bytes) - return checkpoint_id_pb.checkpoint_id - - def check_if_need_rollback(self): - ret = remote_call_pb2.BoolResult() - ret.boolRes = self.__need_rollback - return ret.SerializeToString() - - def request_rollback(self, exception_msg="Python exception."): - logger.info("Request rollback.") - - self.__need_rollback = True - self.__is_recreate = True - - request_ret = False - for i in range(Config.REQUEST_ROLLBACK_RETRY_TIMES): - logger.info("request rollback {} time".format(i)) - try: - request_ret = RemoteCallMst.request_job_worker_rollback( - self.master_actor, - WorkerRollbackRequest( - self.execution_vertex_context.actor_id.binary(), - "Exception msg=%s, retry time=%d." % (exception_msg, - i))) - except Exception: - logger.exception("Unexpected error when rollback") - logger.info("request rollback {} time, ret={}".format( - i, request_ret)) - if not request_ret: - logger.warning( - "Request rollback return false" - ", maybe it's invalid request, try to sleep 1s.") - time.sleep(1) - else: - break - if not request_ret: - logger.warning("Request failed after retry {} times," - "now worker shutdown without reconstruction." - .format(Config.REQUEST_ROLLBACK_RETRY_TIMES)) - self.shutdown_without_reconstruction() - - self.__state.set_type(StateType.WAIT_ROLLBACK) - - def __gen_call_result(self, call_result): - call_result_pb = remote_call_pb2.CallResult() - - call_result_pb.success = call_result.success - call_result_pb.result_code = call_result.result_code.value - if call_result.result_msg is not None: - call_result_pb.result_msg = call_result.result_msg - - if call_result.result_obj is not None: - q_recover_info = call_result.result_obj - for q, status in q_recover_info.get_creation_status().items(): - call_result_pb.result_obj.creation_status[channel_bytes_to_str( - q)] = status.value - - return call_result_pb.SerializeToString() - - def _gen_unique_key(self, key_prefix): - return key_prefix \ - + str(self.config.get(Config.STREAMING_JOB_NAME)) \ - + "_" + str(self.execution_vertex.execution_vertex_id) - - def __get_job_worker_context_key(self) -> str: - return self._gen_unique_key(Config.JOB_WORKER_CONTEXT_KEY) - - -class WorkerState: - """ - worker state - """ - - def __init__(self): - self.__type = StateType.INIT - - def set_type(self, type): - self.__type = type - - def get_type(self): - return self.__type - - -class StateType(enum.Enum): - """ - state type - """ - - INIT = 1 - RUNNING = 2 - WAIT_ROLLBACK = 3 diff --git a/streaming/python/tests/test_direct_transfer.py b/streaming/python/tests/test_direct_transfer.py index 9a9f2892c..5a8866d22 100644 --- a/streaming/python/tests/test_direct_transfer.py +++ b/streaming/python/tests/test_direct_transfer.py @@ -68,7 +68,7 @@ class Worker: if item is None: time.sleep(0.01) else: - msg = pickle.loads(item.body) + msg = pickle.loads(item.body()) count += 1 assert msg == msg_nums - 1 print("ReaderWorker done.") diff --git a/streaming/python/tests/test_failover.py b/streaming/python/tests/test_failover.py deleted file mode 100644 index def93f43e..000000000 --- a/streaming/python/tests/test_failover.py +++ /dev/null @@ -1,107 +0,0 @@ -import subprocess -import time -from typing import List - -import ray -from ray.streaming import StreamingContext - - -def test_word_count(): - try: - ray.init(_load_code_from_local=True, _include_java=True) - # time.sleep(10) # for gdb to attach - ctx = StreamingContext.Builder() \ - .option("streaming.context-backend.type", "local_file") \ - .option( - "streaming.context-backend.file-state.root", - "/tmp/ray/cp_files/" - ) \ - .option("streaming.checkpoint.timeout.secs", "3") \ - .build() - - print("-----------submit job-------------") - - ctx.read_text_file(__file__) \ - .set_parallelism(1) \ - .flat_map(lambda x: x.split()) \ - .map(lambda x: (x, 1)) \ - .key_by(lambda x: x[0]) \ - .reduce(lambda old_value, new_value: - (old_value[0], old_value[1] + new_value[1])) \ - .filter(lambda x: "ray" not in x) \ - .sink(lambda x: print("####result", x)) - ctx.submit("word_count") - - print("-----------checking output-------------") - retry_count = 180 / 5 # wait for 3min - while not has_sink_output(): - time.sleep(5) - retry_count -= 1 - if retry_count <= 0: - raise RuntimeError("Can not find output") - - print("-----------killing worker-------------") - time.sleep(5) - kill_all_worker() - - print("-----------checking checkpoint-------------") - cp_ok_num = checkpoint_success_num() - retry_count = 300000 / 5 # wait for 5min - while True: - cur_cp_num = checkpoint_success_num() - print("-----------checking checkpoint" - ", cur_cp_num={}, old_cp_num={}-------------".format( - cur_cp_num, cp_ok_num)) - if cur_cp_num > cp_ok_num: - print("--------------TEST OK!------------------") - break - time.sleep(5) - retry_count -= 1 - if retry_count <= 0: - raise RuntimeError( - "Checkpoint keeps failing after fail-over, test failed!") - finally: - ray.shutdown() - - -def run_cmd(cmd: List): - try: - out = subprocess.check_output(cmd).decode() - except subprocess.CalledProcessError as e: - out = str(e) - return out - - -def grep_log(keyword: str) -> str: - out = subprocess.check_output( - ["grep", "-r", keyword, "/tmp/ray/session_latest/logs"]) - return out.decode() - - -def has_sink_output() -> bool: - try: - grep_log("####result") - return True - except Exception: - return False - - -def checkpoint_success_num() -> int: - try: - return grep_log("Finish checkpoint").count("\n") - except Exception: - return 0 - - -def kill_all_worker(): - cmd = [ - "bash", "-c", "grep -r \'Initializing job worker, exe_vert\' " - " /tmp/ray/session_latest/logs | awk -F\'pid\' \'{print $2}\'" - "| awk -F\'=\' \'{print $2}\'" + "| xargs kill -9" - ] - print(cmd) - return subprocess.run(cmd) - - -if __name__ == "__main__": - test_word_count() diff --git a/streaming/src/channel/channel.cc b/streaming/src/channel.cc similarity index 73% rename from streaming/src/channel/channel.cc rename to streaming/src/channel.cc index 896ea2169..6816bf972 100644 --- a/streaming/src/channel/channel.cc +++ b/streaming/src/channel.cc @@ -25,10 +25,30 @@ StreamingQueueProducer::~StreamingQueueProducer() { StreamingStatus StreamingQueueProducer::CreateTransferChannel() { CreateQueue(); - STREAMING_LOG(WARNING) << "Message id in channel => " - << channel_info_.current_message_id; + uint64_t queue_last_seq_id = 0; + uint64_t last_message_id_in_queue = 0; - channel_info_.message_last_commit_id = 0; + if (!last_message_id_in_queue) { + if (last_message_id_in_queue < channel_info_.current_message_id) { + STREAMING_LOG(WARNING) << "last message id in queue : " << last_message_id_in_queue + << " is less than message checkpoint loaded id : " + << channel_info_.current_message_id + << ", an old queue object " << channel_info_.channel_id + << " was fond in store"; + } + last_message_id_in_queue = channel_info_.current_message_id; + } + if (queue_last_seq_id == static_cast(-1)) { + queue_last_seq_id = 0; + } + channel_info_.current_seq_id = queue_last_seq_id; + + STREAMING_LOG(WARNING) << "existing last message id => " << last_message_id_in_queue + << ", message id in channel => " + << channel_info_.current_message_id << ", queue last seq id => " + << queue_last_seq_id; + + channel_info_.message_last_commit_id = last_message_id_in_queue; return StreamingStatus::OK; } @@ -49,8 +69,11 @@ StreamingStatus StreamingQueueProducer::CreateQueue() { channel_info_.queue_size); STREAMING_CHECK(queue_ != nullptr); - STREAMING_LOG(INFO) << "StreamingQueueProducer CreateQueue queue id => " - << channel_info_.channel_id << ", queue size => " + std::vector queue_ids, failed_queues; + queue_ids.push_back(channel_info_.channel_id); + upstream_handler->WaitQueues(queue_ids, 10 * 1000, failed_queues); + + STREAMING_LOG(INFO) << "q id => " << channel_info_.channel_id << ", queue size => " << channel_info_.queue_size; return StreamingStatus::OK; @@ -66,29 +89,21 @@ StreamingStatus StreamingQueueProducer::ClearTransferCheckpoint( } StreamingStatus StreamingQueueProducer::RefreshChannelInfo() { - channel_info_.queue_info.consumed_message_id = queue_->GetMinConsumedMsgID(); + channel_info_.queue_info.consumed_seq_id = queue_->GetMinConsumedSeqID(); return StreamingStatus::OK; } -StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t msg_id) { - queue_->SetQueueEvictionLimit(msg_id); +StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t channel_offset) { + queue_->SetQueueEvictionLimit(channel_offset); return StreamingStatus::OK; } StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) { - StreamingMessageBundleMetaPtr meta = StreamingMessageBundleMeta::FromBytes(data); - uint64_t msg_id_end = meta->GetLastMessageId(); - uint64_t msg_id_start = - (meta->GetMessageListSize() == 0 ? msg_id_end - : msg_id_end - meta->GetMessageListSize() + 1); + /// TODO: Fix msg_id_start and msg_id_end + Status status = PushQueueItem(channel_info_.current_seq_id + 1, data, data_size, + current_time_ms(), 0, 0); - STREAMING_LOG(DEBUG) << "ProduceItemToChannel, qid=" << channel_info_.channel_id - << ", msg_id_start=" << msg_id_start - << ", msg_id_end=" << msg_id_end << ", meta=" << *meta; - - Status status = - PushQueueItem(data, data_size, current_time_ms(), msg_id_start, msg_id_end); if (status.code() != StatusCode::OK) { STREAMING_LOG(DEBUG) << channel_info_.channel_id << " => Queue is full" << " meesage => " << status.message(); @@ -105,14 +120,14 @@ StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data, return StreamingStatus::OK; } -Status StreamingQueueProducer::PushQueueItem(uint8_t *data, uint32_t data_size, - uint64_t timestamp, uint64_t msg_id_start, - uint64_t msg_id_end) { +Status StreamingQueueProducer::PushQueueItem(uint64_t seq_id, uint8_t *data, + uint32_t data_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end) { STREAMING_LOG(DEBUG) << "StreamingQueueProducer::PushQueueItem:" - << " qid: " << channel_info_.channel_id + << " qid: " << channel_info_.channel_id << " seq_id: " << seq_id << " data_size: " << data_size; Status status = - queue_->Push(data, data_size, timestamp, msg_id_start, msg_id_end, false); + queue_->Push(seq_id, data, data_size, timestamp, msg_id_start, msg_id_end, false); if (status.IsOutOfMemory()) { status = queue_->TryEvictItems(); if (!status.ok()) { @@ -120,7 +135,8 @@ Status StreamingQueueProducer::PushQueueItem(uint8_t *data, uint32_t data_size, return status; } - status = queue_->Push(data, data_size, timestamp, msg_id_start, msg_id_end, false); + status = + queue_->Push(seq_id, data, data_size, timestamp, msg_id_start, msg_id_end, false); } queue_->Send(); @@ -162,7 +178,7 @@ StreamingQueueStatus StreamingQueueConsumer::GetQueue( TransferCreationStatus StreamingQueueConsumer::CreateTransferChannel() { StreamingQueueStatus status = - GetQueue(channel_info_.channel_id, channel_info_.current_message_id + 1, + GetQueue(channel_info_.channel_id, channel_info_.current_seq_id + 1, channel_info_.parameter); if (status == StreamingQueueStatus::OK) { @@ -188,11 +204,12 @@ StreamingStatus StreamingQueueConsumer::ClearTransferCheckpoint( } StreamingStatus StreamingQueueConsumer::RefreshChannelInfo() { - channel_info_.queue_info.last_message_id = queue_->GetLastRecvMsgId(); + channel_info_.queue_info.last_seq_id = queue_->GetLastRecvSeqId(); return StreamingStatus::OK; } -StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint8_t *&data, +StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint64_t &offset_id, + uint8_t *&data, uint32_t &data_size, uint32_t timeout) { STREAMING_LOG(INFO) << "GetQueueItem qid: " << channel_info_.channel_id; @@ -202,14 +219,16 @@ StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint8_t *&data, STREAMING_LOG(INFO) << "GetQueueItem timeout."; data = nullptr; data_size = 0; + offset_id = QUEUE_INVALID_SEQ_ID; return StreamingStatus::OK; } data = item.Buffer()->Data(); + offset_id = item.SeqId(); data_size = item.Buffer()->Size(); STREAMING_LOG(DEBUG) << "GetQueueItem qid: " << channel_info_.channel_id - << " seq_id: " << item.SeqId() << " msg_id: " << item.MaxMsgId() + << " seq_id: " << offset_id << " msg_id: " << item.MaxMsgId() << " data_size: " << data_size; return StreamingStatus::OK; } @@ -230,7 +249,7 @@ struct MockQueueItem { class MockQueue { public: std::unordered_map>> - message_buffer; + message_bffer; std::unordered_map>> consumed_buffer; std::unordered_map queue_info_map; @@ -245,7 +264,7 @@ std::mutex MockQueue::mutex; StreamingStatus MockProducer::CreateTransferChannel() { std::unique_lock lock(MockQueue::mutex); MockQueue &mock_queue = MockQueue::GetMockQueue(); - mock_queue.message_buffer[channel_info_.channel_id] = + mock_queue.message_bffer[channel_info_.channel_id] = std::make_shared>(10000); mock_queue.consumed_buffer[channel_info_.channel_id] = std::make_shared>(10000); @@ -255,7 +274,7 @@ StreamingStatus MockProducer::CreateTransferChannel() { StreamingStatus MockProducer::DestroyTransferChannel() { std::unique_lock lock(MockQueue::mutex); MockQueue &mock_queue = MockQueue::GetMockQueue(); - mock_queue.message_buffer.erase(channel_info_.channel_id); + mock_queue.message_bffer.erase(channel_info_.channel_id); mock_queue.consumed_buffer.erase(channel_info_.channel_id); return StreamingStatus::OK; } @@ -263,39 +282,44 @@ StreamingStatus MockProducer::DestroyTransferChannel() { StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) { std::unique_lock lock(MockQueue::mutex); MockQueue &mock_queue = MockQueue::GetMockQueue(); - auto &ring_buffer = mock_queue.message_buffer[channel_info_.channel_id]; + auto &ring_buffer = mock_queue.message_bffer[channel_info_.channel_id]; if (ring_buffer->Full()) { return StreamingStatus::OutOfMemory; } MockQueueItem item; + item.seq_id = channel_info_.current_seq_id + 1; item.data.reset(new uint8_t[data_size]); item.data_size = data_size; std::memcpy(item.data.get(), data, data_size); ring_buffer->Push(item); + mock_queue.queue_info_map[channel_info_.channel_id].last_seq_id = item.seq_id; return StreamingStatus::OK; } StreamingStatus MockProducer::RefreshChannelInfo() { MockQueue &mock_queue = MockQueue::GetMockQueue(); - channel_info_.queue_info.consumed_message_id = - mock_queue.queue_info_map[channel_info_.channel_id].consumed_message_id; + channel_info_.queue_info.consumed_seq_id = + mock_queue.queue_info_map[channel_info_.channel_id].consumed_seq_id; return StreamingStatus::OK; } -StreamingStatus MockConsumer::ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, +StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, + uint32_t &data_size, uint32_t timeout) { std::unique_lock lock(MockQueue::mutex); MockQueue &mock_queue = MockQueue::GetMockQueue(); auto &channel_id = channel_info_.channel_id; - if (mock_queue.message_buffer.find(channel_id) == mock_queue.message_buffer.end()) { + if (mock_queue.message_bffer.find(channel_id) == mock_queue.message_bffer.end()) { return StreamingStatus::NoSuchItem; } - if (mock_queue.message_buffer[channel_id]->Empty()) { + + if (mock_queue.message_bffer[channel_id]->Empty()) { return StreamingStatus::NoSuchItem; } - MockQueueItem item = mock_queue.message_buffer[channel_id]->Front(); - mock_queue.message_buffer[channel_id]->Pop(); + MockQueueItem item = mock_queue.message_bffer[channel_id]->Front(); + mock_queue.message_bffer[channel_id]->Pop(); mock_queue.consumed_buffer[channel_id]->Push(item); + offset_id = item.seq_id; data = item.data.get(); data_size = item.data_size; return StreamingStatus::OK; @@ -309,14 +333,14 @@ StreamingStatus MockConsumer::NotifyChannelConsumed(uint64_t offset_id) { while (!ring_buffer->Empty() && ring_buffer->Front().seq_id <= offset_id) { ring_buffer->Pop(); } - mock_queue.queue_info_map[channel_id].consumed_message_id = offset_id; + mock_queue.queue_info_map[channel_id].consumed_seq_id = offset_id; return StreamingStatus::OK; } StreamingStatus MockConsumer::RefreshChannelInfo() { MockQueue &mock_queue = MockQueue::GetMockQueue(); - channel_info_.queue_info.last_message_id = - mock_queue.queue_info_map[channel_info_.channel_id].last_message_id; + channel_info_.queue_info.last_seq_id = + mock_queue.queue_info_map[channel_info_.channel_id].last_seq_id; return StreamingStatus::OK; } diff --git a/streaming/src/channel/channel.h b/streaming/src/channel.h similarity index 89% rename from streaming/src/channel/channel.h rename to streaming/src/channel.h index 733a07042..4bb46aa15 100644 --- a/streaming/src/channel/channel.h +++ b/streaming/src/channel.h @@ -1,10 +1,9 @@ #pragma once -#include "common/status.h" #include "config/streaming_config.h" #include "queue/queue_handler.h" -#include "ring_buffer/ring_buffer.h" -#include "util/config.h" +#include "ring_buffer.h" +#include "status.h" #include "util/streaming_util.h" namespace ray { @@ -20,9 +19,9 @@ enum class TransferCreationStatus : uint32_t { struct StreamingQueueInfo { uint64_t first_seq_id = 0; - uint64_t last_message_id = 0; - uint64_t target_message_id = 0; - uint64_t consumed_message_id = 0; + uint64_t last_seq_id = 0; + uint64_t target_seq_id = 0; + uint64_t consumed_seq_id = 0; }; struct ChannelCreationParameter { @@ -37,6 +36,7 @@ struct ProducerChannelInfo { ObjectID channel_id; StreamingRingBufferPtr writer_ring_buffer; uint64_t current_message_id; + uint64_t current_seq_id; uint64_t message_last_commit_id; StreamingQueueInfo queue_info; uint32_t queue_size; @@ -58,6 +58,7 @@ struct ProducerChannelInfo { struct ConsumerChannelInfo { ObjectID channel_id; uint64_t current_message_id; + uint64_t current_seq_id; uint64_t barrier_id; uint64_t partial_barrier_id; @@ -70,7 +71,6 @@ struct ConsumerChannelInfo { ChannelCreationParameter parameter; // Total count of notify request. uint64_t notify_cnt = 0; - uint64_t resend_notify_timer; }; /// Two types of channel are presented: @@ -111,7 +111,8 @@ class ConsumerChannel { virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, uint64_t checkpoint_offset) = 0; virtual StreamingStatus RefreshChannelInfo() = 0; - virtual StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, + virtual StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, + uint32_t &data_size, uint32_t timeout) = 0; virtual StreamingStatus NotifyChannelConsumed(uint64_t offset_id) = 0; @@ -135,8 +136,8 @@ class StreamingQueueProducer : public ProducerChannel { private: StreamingStatus CreateQueue(); - Status PushQueueItem(uint8_t *data, uint32_t data_size, uint64_t timestamp, - uint64_t msg_id_start, uint64_t msg_id_end); + Status PushQueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, + uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end); private: std::shared_ptr queue_; @@ -152,8 +153,8 @@ class StreamingQueueConsumer : public ConsumerChannel { StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, uint64_t checkpoint_offset) override; StreamingStatus RefreshChannelInfo() override; - StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, - uint32_t timeout) override; + StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, + uint32_t &data_size, uint32_t timeout) override; StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; private: @@ -203,8 +204,8 @@ class MockConsumer : public ConsumerChannel { return StreamingStatus::OK; } StreamingStatus RefreshChannelInfo() override; - StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, - uint32_t timeout) override; + StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, + uint32_t &data_size, uint32_t timeout) override; StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; }; diff --git a/streaming/src/config/streaming_config.cc b/streaming/src/config/streaming_config.cc index 7dc94c865..f63b00d2e 100644 --- a/streaming/src/config/streaming_config.cc +++ b/streaming/src/config/streaming_config.cc @@ -10,7 +10,6 @@ uint32_t StreamingConfig::DEFAULT_RING_BUFFER_CAPACITY = 500; uint32_t StreamingConfig::DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL = 20; // Time to force clean if barrier in queue, default 0ms const uint32_t StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE = 2048; -const uint32_t StreamingConfig::RESEND_NOTIFY_MAX_INTERVAL = 1000; // ms #define RESET_IF_INT_CONF(KEY, VALUE) \ if (0 != VALUE) { \ diff --git a/streaming/src/config/streaming_config.h b/streaming/src/config/streaming_config.h index 784f7b094..8f0cf2d5b 100644 --- a/streaming/src/config/streaming_config.h +++ b/streaming/src/config/streaming_config.h @@ -9,20 +9,12 @@ namespace ray { namespace streaming { -using ReliabilityLevel = proto::ReliabilityLevel; -using StreamingRole = proto::NodeType; - -#define DECL_GET_SET_PROPERTY(TYPE, NAME, VALUE) \ - TYPE Get##NAME() const { return VALUE; } \ - void Set##NAME(TYPE value) { VALUE = value; } - class StreamingConfig { public: static uint64_t TIME_WAIT_UINT; static uint32_t DEFAULT_RING_BUFFER_CAPACITY; static uint32_t DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL; static const uint32_t MESSAGE_BUNDLE_MAX_SIZE; - static const uint32_t RESEND_NOTIFY_MAX_INTERVAL; private: uint32_t ring_buffer_capacity_ = DEFAULT_RING_BUFFER_CAPACITY; @@ -48,18 +40,12 @@ class StreamingConfig { uint32_t event_driven_flow_control_interval_ = 1; - ReliabilityLevel streaming_strategy_ = ReliabilityLevel::EXACTLY_ONCE; - StreamingRole streaming_role = StreamingRole::TRANSFORM; - public: void FromProto(const uint8_t *, uint32_t size); - inline bool IsAtLeastOnce() const { - return ReliabilityLevel::AT_LEAST_ONCE == streaming_strategy_; - } - inline bool IsExactlyOnce() const { - return ReliabilityLevel::EXACTLY_ONCE == streaming_strategy_; - } +#define DECL_GET_SET_PROPERTY(TYPE, NAME, VALUE) \ + TYPE Get##NAME() const { return VALUE; } \ + void Set##NAME(TYPE value) { VALUE = value; } DECL_GET_SET_PROPERTY(const std::string &, WorkerName, worker_name_) DECL_GET_SET_PROPERTY(const std::string &, OpName, op_name_) @@ -72,8 +58,6 @@ class StreamingConfig { flow_control_type_) DECL_GET_SET_PROPERTY(uint32_t, EventDrivenFlowControlInterval, event_driven_flow_control_interval_) - DECL_GET_SET_PROPERTY(StreamingRole, StreamingRole, streaming_role) - DECL_GET_SET_PROPERTY(ReliabilityLevel, ReliabilityLevel, streaming_strategy_) uint32_t GetRingBufferCapacity() const; /// Note(lingxuan.zlx), RingBufferCapacity's valid range is from 1 to diff --git a/streaming/src/data_reader.cc b/streaming/src/data_reader.cc index 5697fb003..ca59cf987 100644 --- a/streaming/src/data_reader.cc +++ b/streaming/src/data_reader.cc @@ -18,16 +18,15 @@ const uint32_t DataReader::kReadItemTimeout = 1000; void DataReader::Init(const std::vector &input_ids, const std::vector &init_params, + const std::vector &queue_seq_ids, const std::vector &streaming_msg_ids, - std::vector &creation_status, int64_t timer_interval) { Init(input_ids, init_params, timer_interval); for (size_t i = 0; i < input_ids.size(); ++i) { auto &q_id = input_ids[i]; - last_message_id_[q_id] = streaming_msg_ids[i]; + channel_info_map_[q_id].current_seq_id = queue_seq_ids[i]; channel_info_map_[q_id].current_message_id = streaming_msg_ids[i]; } - InitChannel(creation_status); } void DataReader::Init(const std::vector &input_ids, @@ -54,23 +53,19 @@ void DataReader::Init(const std::vector &input_ids, channel_info.last_queue_item_latency = 0; channel_info.last_queue_target_diff = 0; channel_info.get_queue_item_times = 0; - channel_info.resend_notify_timer = 0; } - reliability_helper_ = ReliabilityHelperFactory::CreateReliabilityHelper( - runtime_context_->GetConfig(), barrier_helper_, nullptr, this); - /// Make the input id location stable. sort(input_queue_ids_.begin(), input_queue_ids_.end(), [](const ObjectID &a, const ObjectID &b) { return a.Hash() < b.Hash(); }); std::copy(input_ids.begin(), input_ids.end(), std::back_inserter(unready_queue_ids_)); + InitChannel(); } -StreamingStatus DataReader::InitChannel( - std::vector &creation_status) { +StreamingStatus DataReader::InitChannel() { STREAMING_LOG(INFO) << "[Reader] Getting queues. total queue num " - << input_queue_ids_.size() - << ", unready queue num=" << unready_queue_ids_.size(); + << input_queue_ids_.size() << ", unready queue num => " + << unready_queue_ids_.size(); for (const auto &input_channel : unready_queue_ids_) { auto &channel_info = channel_info_map_[input_channel]; @@ -83,10 +78,8 @@ StreamingStatus DataReader::InitChannel( channel_map_.emplace(input_channel, channel); TransferCreationStatus status = channel->CreateTransferChannel(); - creation_status.push_back(status); if (TransferCreationStatus::PullOk != status) { - STREAMING_LOG(ERROR) << "Initialize queue failed, id=" << input_channel - << ", status=" << static_cast(status); + STREAMING_LOG(ERROR) << "Initialize queue failed, id => " << input_channel; } } runtime_context_->SetRuntimeStatus(RuntimeStatus::Running); @@ -94,11 +87,10 @@ StreamingStatus DataReader::InitChannel( return StreamingStatus::OK; } -StreamingStatus DataReader::InitChannelMerger(uint32_t timeout_ms) { +StreamingStatus DataReader::InitChannelMerger() { STREAMING_LOG(INFO) << "[Reader] Initializing queue merger."; // Init reader merger by given comparator when it's first created. - StreamingReaderMsgPtrComparator comparator( - runtime_context_->GetConfig().GetReliabilityLevel()); + StreamingReaderMsgPtrComparator comparator; if (!reader_merger_) { reader_merger_.reset( new PriorityQueue, StreamingReaderMsgPtrComparator>( @@ -108,255 +100,106 @@ StreamingStatus DataReader::InitChannelMerger(uint32_t timeout_ms) { // An old item in merger vector must be evicted before new queue item has been // pushed. if (!unready_queue_ids_.empty() && last_fetched_queue_item_) { - STREAMING_LOG(INFO) << "pop old item from=" << last_fetched_queue_item_->from; - RETURN_IF_NOT_OK(StashNextMessageAndPop(last_fetched_queue_item_, timeout_ms)) + STREAMING_LOG(INFO) << "pop old item from => " << last_fetched_queue_item_->from; + RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_)) last_fetched_queue_item_.reset(); } // Create initial heap for priority queue. - std::vector unready_queue_ids_stashed; for (auto &input_queue : unready_queue_ids_) { std::shared_ptr msg = std::make_shared(); - auto status = GetMessageFromChannel(channel_info_map_[input_queue], msg, timeout_ms, - timeout_ms); - if (StreamingStatus::OK != status) { - STREAMING_LOG(INFO) - << "[Reader] initializing merger, get message from channel timeout, " - << input_queue << ", status => " << static_cast(status); - unready_queue_ids_stashed.push_back(input_queue); - continue; - } + RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info_map_[input_queue], msg)) + channel_info_map_[msg->from].current_seq_id = msg->seq_id; channel_info_map_[msg->from].current_message_id = msg->meta->GetLastMessageId(); reader_merger_->push(msg); } - if (unready_queue_ids_stashed.empty()) { - STREAMING_LOG(INFO) << "[Reader] Initializing merger done."; - return StreamingStatus::OK; - } else { - STREAMING_LOG(INFO) << "[Reader] Initializing merger unfinished."; - unready_queue_ids_ = unready_queue_ids_stashed; - return StreamingStatus::GetBundleTimeOut; - } + STREAMING_LOG(INFO) << "[Reader] Initializing merger done."; + return StreamingStatus::OK; } StreamingStatus DataReader::GetMessageFromChannel(ConsumerChannelInfo &channel_info, - std::shared_ptr &message, - uint32_t timeout_ms, - uint32_t wait_time_ms) { + std::shared_ptr &message) { auto &qid = channel_info.channel_id; - message->from = qid; last_read_q_id_ = qid; - - bool is_valid_bundle = false; - int64_t start_time = current_sys_time_ms(); - STREAMING_LOG(DEBUG) << "GetMessageFromChannel, timeout_ms=" << timeout_ms - << ", wait_time_ms=" << wait_time_ms; - while (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Running && - !is_valid_bundle && current_sys_time_ms() - start_time < timeout_ms) { - STREAMING_LOG(DEBUG) << "[Reader] send get request queue seq id=" << qid; - /// In AT_LEAST_ONCE, wait_time_ms is set to 0, means `ConsumeItemFromChannel` - /// will return immediately if no items in queue. At the same time, `timeout_ms` is - /// ignored. - channel_map_[channel_info.channel_id]->ConsumeItemFromChannel( - message->data, message->data_size, wait_time_ms); - + STREAMING_LOG(DEBUG) << "[Reader] send get request queue seq id => " << qid; + while (RuntimeStatus::Running == runtime_context_->GetRuntimeStatus() && + !message->data) { + auto status = channel_map_[channel_info.channel_id]->ConsumeItemFromChannel( + message->seq_id, message->data, message->data_size, kReadItemTimeout); channel_info.get_queue_item_times++; if (!message->data) { - RETURN_IF_NOT_OK(reliability_helper_->HandleNoValidItem(channel_info)); - } else { - uint64_t current_time = current_sys_time_ms(); - channel_info.resend_notify_timer = current_time; - // Note(lingxuan.zlx): To find which channel get an invalid data and - // print channel id for debugging. - STREAMING_CHECK(StreamingMessageBundleMeta::CheckBundleMagicNum(message->data)) - << "Magic number invalid, from channel " << channel_info.channel_id; - message->meta = StreamingMessageBundleMeta::FromBytes(message->data); - - is_valid_bundle = true; - if (!runtime_context_->GetConfig().IsAtLeastOnce()) { - // filter message when msg_id doesn't match. - // reader will filter message only when using streaming queue and - // non-at-least-once mode - BundleCheckStatus status = CheckBundle(message); - STREAMING_LOG(DEBUG) << "CheckBundle, result=" << status - << ", last_msg_id=" << last_message_id_[message->from]; - if (status == BundleCheckStatus::BundleToBeSplit) { - SplitBundle(message, last_message_id_[qid]); - } - if (status == BundleCheckStatus::BundleToBeThrown && message->meta->IsBarrier()) { - STREAMING_LOG(WARNING) - << "Throw barrier, msg_id=" << message->meta->GetLastMessageId(); - } - is_valid_bundle = status != BundleCheckStatus::BundleToBeThrown; - } + STREAMING_LOG(DEBUG) << "[Reader] Queue " << qid << " status " << status + << " get item timeout, resend notify " + << channel_info.current_seq_id; + // TODO(lingxuan.zlx): notify consumed when it's timeout. } } if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) { return StreamingStatus::Interrupted; } + STREAMING_LOG(DEBUG) << "[Reader] recevied queue seq id => " << message->seq_id + << ", queue id => " << qid; - if (!is_valid_bundle) { - STREAMING_LOG(DEBUG) << "GetMessageFromChannel timeout, qid=" - << channel_info.channel_id; - return StreamingStatus::GetBundleTimeOut; - } - - STREAMING_LOG(DEBUG) << "[Reader] received message id=" - << message->meta->GetLastMessageId() << ", queue id=" << qid; - last_message_id_[message->from] = message->meta->GetLastMessageId(); + message->from = qid; + message->meta = StreamingMessageBundleMeta::FromBytes(message->data); return StreamingStatus::OK; } -BundleCheckStatus DataReader::CheckBundle(const std::shared_ptr &message) { - uint64_t end_msg_id = message->meta->GetLastMessageId(); - uint64_t start_msg_id = message->meta->IsEmptyMsg() - ? end_msg_id - : end_msg_id - message->meta->GetMessageListSize() + 1; - uint64_t last_msg_id = last_message_id_[message->from]; - - // Writer will keep sending bundles when downstream reader failover. After reader - // recovered, it will receive these bundles whoes msg_id is larger than expected. - if (start_msg_id > last_msg_id + 1) { - return BundleCheckStatus::BundleToBeThrown; - } - if (end_msg_id < last_msg_id + 1) { - // Empty message and barrier's msg_id equals to last message, so we shouldn't throw - // them. - return end_msg_id == last_msg_id && !message->meta->IsBundle() - ? BundleCheckStatus::OkBundle - : BundleCheckStatus::BundleToBeThrown; - } - // Normal bundles. - if (start_msg_id == last_msg_id + 1) { - return BundleCheckStatus::OkBundle; - } - return BundleCheckStatus::BundleToBeSplit; -} - -void DataReader::SplitBundle(std::shared_ptr &message, uint64_t last_msg_id) { - std::list msg_list; - StreamingMessageBundle::GetMessageListFromRawData( - message->data + kMessageBundleHeaderSize, - message->data_size - kMessageBundleHeaderSize, message->meta->GetMessageListSize(), - msg_list); - uint64_t bundle_size = 0; - for (auto it = msg_list.begin(); it != msg_list.end();) { - if ((*it)->GetMessageId() > last_msg_id) { - bundle_size += (*it)->ClassBytesSize(); - it++; - } else { - it = msg_list.erase(it); - } - } - STREAMING_LOG(DEBUG) << "Split message, from_queue_id=" << message->from - << ", start_msg_id=" << msg_list.front()->GetMessageId() - << ", end_msg_id=" << msg_list.back()->GetMessageId(); - // recreate bundle - auto cut_msg_bundle = std::make_shared( - msg_list, message->meta->GetMessageBundleTs(), msg_list.back()->GetMessageId(), - StreamingMessageBundleType::Bundle, bundle_size); - message->Realloc(cut_msg_bundle->ClassBytesSize()); - cut_msg_bundle->ToBytes(message->data); - message->meta = StreamingMessageBundleMeta::FromBytes(message->data); -} - -StreamingStatus DataReader::StashNextMessageAndPop(std::shared_ptr &message, - uint32_t timeout_ms) { - STREAMING_LOG(DEBUG) << "StashNextMessageAndPop, timeout_ms=" << timeout_ms; - - // Get the first message. - message = reader_merger_->top(); - STREAMING_LOG(DEBUG) << "Messages to be poped=" << *message - << ", merger size=" << reader_merger_->size(); - - // Then stash next message from its from queue. +StreamingStatus DataReader::StashNextMessage(std::shared_ptr &message) { + // Push new message into priority queue and record the channel metrics in + // channel info. std::shared_ptr new_msg = std::make_shared(); auto &channel_info = channel_info_map_[message->from]; - RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info, new_msg, timeout_ms, timeout_ms)) - new_msg->last_barrier_id = channel_info.barrier_id; - reader_merger_->push(new_msg); - STREAMING_LOG(DEBUG) << "New message pushed=" << *new_msg - << ", merger size=" << reader_merger_->size(); - - // Pop message. reader_merger_->pop(); - STREAMING_LOG(DEBUG) << "Message poped, msg=" << *message; - - // Record some metrics. + int64_t cur_time = current_time_ms(); + RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info, new_msg)) + reader_merger_->push(new_msg); channel_info.last_queue_item_delay = new_msg->meta->GetMessageBundleTs() - message->meta->GetMessageBundleTs(); - channel_info.last_queue_item_latency = current_time_ms() - current_time_ms(); + channel_info.last_queue_item_latency = current_time_ms() - cur_time; return StreamingStatus::OK; } StreamingStatus DataReader::GetMergedMessageBundle(std::shared_ptr &message, - bool &is_valid_break, - uint32_t timeout_ms) { - RETURN_IF_NOT_OK(StashNextMessageAndPop(message, timeout_ms)) - - auto &offset_info = channel_info_map_[message->from]; - uint64_t cur_queue_previous_msg_id = offset_info.current_message_id; - STREAMING_LOG(DEBUG) << "[Reader] [Bundle]" << *message - << ", cur_queue_previous_msg_id=" << cur_queue_previous_msg_id; + bool &is_valid_break) { int64_t cur_time = current_time_ms(); + if (last_fetched_queue_item_) { + RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_)) + } + message = reader_merger_->top(); + last_fetched_queue_item_ = message; + auto &offset_info = channel_info_map_[message->from]; + + uint64_t cur_queue_previous_msg_id = offset_info.current_message_id; + STREAMING_LOG(DEBUG) << "[Reader] [Bundle] from q_id =>" << message->from << "cur => " + << cur_queue_previous_msg_id << ", message list size" + << message->meta->GetMessageListSize() << ", lst message id =>" + << message->meta->GetLastMessageId() << ", q seq id => " + << message->seq_id << ", last barrier id => " << message->data_size + << ", " << message->meta->GetMessageBundleTs(); + if (message->meta->IsBundle()) { last_message_ts_ = cur_time; is_valid_break = true; - } else if (message->meta->IsBarrier() && BarrierAlign(message)) { - last_message_ts_ = cur_time; - is_valid_break = true; - } else if (timer_interval_ != -1 && cur_time - last_message_ts_ >= timer_interval_ && - message->meta->IsEmptyMsg()) { - // Sent empty message when reaching timer_interval + } else if (timer_interval_ != -1 && cur_time - last_message_ts_ > timer_interval_) { + // Throw empty message when reaching timer_interval. last_message_ts_ = cur_time; is_valid_break = true; } offset_info.current_message_id = message->meta->GetLastMessageId(); + offset_info.current_seq_id = message->seq_id; last_bundle_ts_ = message->meta->GetMessageBundleTs(); - STREAMING_LOG(DEBUG) << "[Reader] [Bundle] Get merged message bundle=" << *message - << ", is_valid_break=" << is_valid_break; - last_fetched_queue_item_ = message; + STREAMING_LOG(DEBUG) << "[Reader] [Bundle] message type =>" + << static_cast(message->meta->GetBundleType()) + << " from id => " << message->from << ", queue seq id =>" + << message->seq_id << ", message id => " + << message->meta->GetLastMessageId(); return StreamingStatus::OK; } -bool DataReader::BarrierAlign(std::shared_ptr &message) { - // Arrange barrier action when barrier is arriving. - StreamingBarrierHeader barrier_header; - StreamingMessage::GetBarrierIdFromRawData(message->data + kMessageHeaderSize, - &barrier_header); - uint64_t barrier_id = barrier_header.barrier_id; - auto *barrier_align_cnt = &global_barrier_cnt_; - auto &channel_info = channel_info_map_[message->from]; - // Target count is input vector size (global barrier). - uint32_t target_count = 0; - - channel_info.barrier_id = barrier_header.barrier_id; - target_count = input_queue_ids_.size(); - (*barrier_align_cnt)[barrier_id]++; - // The next message checkpoint is changed if this's barrier message. - STREAMING_LOG(INFO) << "[Reader] [Barrier] get barrier, barrier_id=" << barrier_id - << ", barrier_cnt=" << (*barrier_align_cnt)[barrier_id] - << ", global barrier id=" << barrier_header.barrier_id - << ", from q_id=" << message->from << ", barrier type=" - << static_cast(barrier_header.barrier_type) - << ", target count=" << target_count; - // Notify invoker the last barrier, so that checkpoint or something related can be - // taken right now. - if ((*barrier_align_cnt)[barrier_id] == target_count) { - // map can't be used in multithread (crash in report timer) - barrier_align_cnt->erase(barrier_id); - STREAMING_LOG(INFO) - << "[Reader] [Barrier] last barrier received, return barrier. barrier_id = " - << barrier_id << ", from q_id=" << message->from; - return true; - } - return false; -} - StreamingStatus DataReader::GetBundle(const uint32_t timeout_ms, std::shared_ptr &message) { - STREAMING_LOG(DEBUG) << "GetBundle, timeout_ms=" << timeout_ms; // Notify upstream that last fetched item has been consumed. if (last_fetched_queue_item_) { NotifyConsumed(last_fetched_queue_item_); @@ -379,25 +222,28 @@ StreamingStatus DataReader::GetBundle(const uint32_t timeout_ms, return StreamingStatus::GetBundleTimeOut; } if (!unready_queue_ids_.empty()) { - std::vector creation_status; - StreamingStatus status = InitChannel(creation_status); + StreamingStatus status = InitChannel(); switch (status) { case StreamingStatus::InitQueueFailed: break; + case StreamingStatus::WaitQueueTimeOut: + STREAMING_LOG(ERROR) + << "Wait upstream queue timeout, maybe some actors in deadlock"; + break; default: STREAMING_LOG(INFO) << "Init reader queue in GetBundle"; } if (StreamingStatus::OK != status) { return status; } - RETURN_IF_NOT_OK(InitChannelMerger(timeout_ms)) + RETURN_IF_NOT_OK(InitChannelMerger()) unready_queue_ids_.clear(); auto &merge_vec = reader_merger_->getRawVector(); for (auto &bundle : merge_vec) { - STREAMING_LOG(INFO) << "merger vector item=" << bundle->from; + STREAMING_LOG(INFO) << "merger vector item => " << bundle->from; } } - RETURN_IF_NOT_OK(GetMergedMessageBundle(message, is_valid_break, timeout_ms)); + RETURN_IF_NOT_OK(GetMergedMessageBundle(message, is_valid_break)); if (!is_valid_break) { empty_bundle_cnt++; NotifyConsumed(message); @@ -415,12 +261,16 @@ void DataReader::GetOffsetInfo( offset_map = &channel_info_map_; for (auto &offset_info : channel_info_map_) { STREAMING_LOG(INFO) << "[Reader] [GetOffsetInfo], q id " << offset_info.first - << ", message id=" << offset_info.second.current_message_id; + << ", seq id => " << offset_info.second.current_seq_id + << ", message id => " << offset_info.second.current_message_id; } } void DataReader::NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset) { channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset); + if (offset == channel_info.queue_info.last_seq_id) { + STREAMING_LOG(DEBUG) << "notify seq id equal to last seq id => " << offset; + } } DataReader::DataReader(std::shared_ptr &runtime_context) @@ -436,42 +286,37 @@ void DataReader::NotifyConsumed(std::shared_ptr &message) { auto &channel_info = channel_info_map_[message->from]; auto &queue_info = channel_info.queue_info; channel_info.notify_cnt++; - if (queue_info.target_message_id <= message->meta->GetLastMessageId()) { - NotifyConsumedItem(channel_info, message->meta->GetLastMessageId()); + if (queue_info.target_seq_id <= message->seq_id) { + NotifyConsumedItem(channel_info, message->seq_id); channel_map_[channel_info.channel_id]->RefreshChannelInfo(); - if (queue_info.last_message_id != QUEUE_INVALID_SEQ_ID) { - uint64_t original_target_message_id = queue_info.target_message_id; - queue_info.target_message_id = - std::min(queue_info.last_message_id, - message->meta->GetLastMessageId() + - runtime_context_->GetConfig().GetReaderConsumedStep()); + if (queue_info.last_seq_id != QUEUE_INVALID_SEQ_ID) { + uint64_t original_target_seq_id = queue_info.target_seq_id; + queue_info.target_seq_id = std::min( + queue_info.last_seq_id, + message->seq_id + runtime_context_->GetConfig().GetReaderConsumedStep()); channel_info.last_queue_target_diff = - queue_info.target_message_id - original_target_message_id; + queue_info.target_seq_id - original_target_seq_id; } else { STREAMING_LOG(WARNING) << "[Reader] [QueueInfo] channel id " << message->from - << ", last message id " << queue_info.last_message_id; + << ", last seq id " << queue_info.last_seq_id; } STREAMING_LOG(DEBUG) << "[Reader] [Consumed] Trigger notify consumed" - << ", channel id=" << message->from - << ", last message id=" << queue_info.last_message_id - << ", target message id=" << queue_info.target_message_id - << ", consumed message id=" << message->meta->GetLastMessageId() - << ", bundle type=" + << ", channel id => " << message->from << ", last seq id => " + << queue_info.last_seq_id << ", target seq id => " + << queue_info.target_seq_id << ", consumed seq id => " + << message->seq_id << ", last message id => " + << message->meta->GetLastMessageId() << ", bundle type => " << static_cast(message->meta->GetBundleType()) - << ", last message bundle ts=" + << ", last message bundle ts => " << message->meta->GetMessageBundleTs(); } } bool StreamingReaderMsgPtrComparator::operator()(const std::shared_ptr &a, const std::shared_ptr &b) { - if (comp_strategy == ReliabilityLevel::EXACTLY_ONCE) { - if (a->last_barrier_id != b->last_barrier_id) - return a->last_barrier_id > b->last_barrier_id; - } STREAMING_CHECK(a->meta); - // We proposed fixed id sequnce for stability of message in sorting. + // We use hash value of id for stability of message in sorting. if (a->meta->GetMessageBundleTs() == b->meta->GetMessageBundleTs()) { return a->from.Hash() > b->from.Hash(); } diff --git a/streaming/src/data_reader.h b/streaming/src/data_reader.h index 615047ecc..ac0d1c469 100644 --- a/streaming/src/data_reader.h +++ b/streaming/src/data_reader.h @@ -7,38 +7,27 @@ #include #include -#include "channel/channel.h" +#include "channel.h" #include "message/message_bundle.h" #include "message/priority_queue.h" -#include "reliability/barrier_helper.h" -#include "reliability_helper.h" #include "runtime_context.h" namespace ray { namespace streaming { -class ReliabilityHelper; -class AtLeastOnceHelper; - -enum class BundleCheckStatus : uint32_t { - OkBundle = 0, - BundleToBeThrown = 1, - BundleToBeSplit = 2 +/// Databundle is super-bundle that contains channel information (upstream +/// channel id & bundle meta data) and raw buffer pointer. +struct DataBundle { + uint8_t *data = nullptr; + uint32_t data_size; + ObjectID from; + uint64_t seq_id; + StreamingMessageBundleMetaPtr meta; }; -static inline std::ostream &operator<<(std::ostream &os, - const BundleCheckStatus &status) { - os << static_cast::type>(status); - return os; -} - /// This is implementation of merger policy in StreamingReaderMsgPtrComparator. struct StreamingReaderMsgPtrComparator { - explicit StreamingReaderMsgPtrComparator(ReliabilityLevel strategy) - : comp_strategy(strategy){}; - StreamingReaderMsgPtrComparator(){}; - ReliabilityLevel comp_strategy = ReliabilityLevel::EXACTLY_ONCE; - + StreamingReaderMsgPtrComparator() = default; bool operator()(const std::shared_ptr &a, const std::shared_ptr &b); }; @@ -61,8 +50,6 @@ class DataReader { std::shared_ptr last_fetched_queue_item_; - std::unordered_map global_barrier_cnt_; - int64_t timer_interval_; int64_t last_bundle_ts_; int64_t last_message_ts_; @@ -72,12 +59,6 @@ class DataReader { ObjectID last_read_q_id_; static const uint32_t kReadItemTimeout; - StreamingBarrierHelper barrier_helper_; - std::shared_ptr reliability_helper_; - std::unordered_map last_message_id_; - - friend class ReliabilityHelper; - friend class AtLeastOnceHelper; protected: std::unordered_map channel_info_map_; @@ -92,20 +73,15 @@ class DataReader { /// During initialization, only the channel parameters and necessary member properties /// are assigned. All channels will be connected in the first reading operation. /// \param input_ids - /// \param init_params + /// \param actor_ids + /// \param channel_seq_ids /// \param msg_ids - /// \param[out] creation_status /// \param timer_interval void Init(const std::vector &input_ids, const std::vector &init_params, - const std::vector &msg_ids, - std::vector &creation_status, int64_t timer_interval); + const std::vector &channel_seq_ids, + const std::vector &msg_ids, int64_t timer_interval); - /// Create reader use msg_id=0, this method is public only for test, and users - /// usuallly don't need it. - /// \param input_ids - /// \param init_params - /// \param timer_interval void Init(const std::vector &input_ids, const std::vector &init_params, int64_t timer_interval); @@ -132,30 +108,22 @@ class DataReader { private: /// Create channels and connect to all upstream. - StreamingStatus InitChannel(std::vector &creation_status); + StreamingStatus InitChannel(); /// One item from every channel will be popped out, then collecting /// them to a merged queue. High prioprity items will be fetched one by one. /// When item pop from one channel where must produce new item for placeholder /// in merged queue. - StreamingStatus InitChannelMerger(uint32_t timeout_ms); + StreamingStatus InitChannelMerger(); - StreamingStatus StashNextMessageAndPop(std::shared_ptr &message, - uint32_t timeout_ms); + StreamingStatus StashNextMessage(std::shared_ptr &message); StreamingStatus GetMessageFromChannel(ConsumerChannelInfo &channel_info, - std::shared_ptr &message, - uint32_t timeout_ms, uint32_t wait_time_ms); + std::shared_ptr &message); /// Get top item from prioprity queue. StreamingStatus GetMergedMessageBundle(std::shared_ptr &message, - bool &is_valid_break, uint32_t timeout_ms); - - bool BarrierAlign(std::shared_ptr &message); - - BundleCheckStatus CheckBundle(const std::shared_ptr &message); - - static void SplitBundle(std::shared_ptr &message, uint64_t last_msg_id); + bool &is_valid_break); }; } // namespace streaming } // namespace ray diff --git a/streaming/src/data_writer.cc b/streaming/src/data_writer.cc index ebb01b26c..733ee2a34 100644 --- a/streaming/src/data_writer.cc +++ b/streaming/src/data_writer.cc @@ -63,9 +63,7 @@ uint64_t DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *dat uint32_t data_size, StreamingMessageType message_type) { STREAMING_LOG(DEBUG) << "WriteMessageToBufferRing q_id: " << q_id - << " data_size: " << data_size - << ", message_type=" << static_cast(message_type) - << ", data=" << Util::Byte2hex(data, data_size); + << " data_size: " << data_size; // TODO(lingxuan.zlx): currently, unsafe in multithreads ProducerChannelInfo &channel_info = channel_info_map_[q_id]; // Write message id stands for current lastest message id and differs from @@ -154,9 +152,6 @@ StreamingStatus DataWriter::Init(const std::vector &queue_id_vec, flow_controller_ = std::make_shared(); break; } - - reliability_helper_ = ReliabilityHelperFactory::CreateReliabilityHelper( - runtime_context_->GetConfig(), barrier_helper_, this, nullptr); // Register empty event and user event to event server. event_service_ = std::make_shared(); event_service_->Register( @@ -171,49 +166,6 @@ StreamingStatus DataWriter::Init(const std::vector &queue_id_vec, return StreamingStatus::OK; } -void DataWriter::BroadcastBarrier(uint64_t barrier_id, const uint8_t *data, - uint32_t data_size) { - STREAMING_LOG(INFO) << "broadcast checkpoint id : " << barrier_id; - barrier_helper_.MapBarrierToCheckpoint(barrier_id, barrier_id); - - if (barrier_helper_.Contains(barrier_id)) { - STREAMING_LOG(WARNING) << "replicated global barrier id => " << barrier_id; - return; - } - - std::vector barrier_id_vec; - barrier_helper_.GetAllBarrier(barrier_id_vec); - if (barrier_id_vec.size() > 0) { - // Show all stashed barrier ids that means these checkpoint are not finished - // yet. - STREAMING_LOG(WARNING) << "[Writer] [Barrier] previous barrier(checkpoint) was fail " - "to do some opearting, ids => " - << Util::join(barrier_id_vec.begin(), barrier_id_vec.end(), - "|"); - } - StreamingBarrierHeader barrier_header = { - .barrier_type = StreamingBarrierType::GlobalBarrier, .barrier_id = barrier_id}; - - auto barrier_payload = - StreamingMessage::MakeBarrierPayload(barrier_header, data, data_size); - auto payload_size = kBarrierHeaderSize + data_size; - for (auto &queue_id : output_queue_ids_) { - uint64_t barrier_message_id = WriteMessageToBufferRing( - queue_id, barrier_payload.get(), payload_size, StreamingMessageType::Barrier); - if (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Interrupted) { - STREAMING_LOG(WARNING) << " stop right now"; - return; - } - - STREAMING_LOG(INFO) << "[Writer] [Barrier] write barrier to => " << queue_id - << ", barrier message id =>" << barrier_message_id - << ", barrier id => " << barrier_id; - } - - STREAMING_LOG(INFO) << "[Writer] [Barrier] global barrier id in runtime => " - << barrier_id; -} - DataWriter::DataWriter(std::shared_ptr &runtime_context) : transfer_config_(new Config()), runtime_context_(runtime_context) {} @@ -285,6 +237,7 @@ StreamingStatus DataWriter::WriteEmptyMessage(ProducerChannelInfo &channel_info) q_ringbuffer->FreeTransientBuffer(); RETURN_IF_NOT_OK(status) + channel_info.current_seq_id++; channel_info.message_pass_by_ts = current_time_ms(); return StreamingStatus::OK; } @@ -295,6 +248,7 @@ StreamingStatus DataWriter::WriteTransientBufferToChannel( StreamingStatus status = channel_map_[channel_info.channel_id]->ProduceItemToChannel( buffer_ptr->GetTransientBufferMutable(), buffer_ptr->GetTransientBufferSize()); RETURN_IF_NOT_OK(status) + channel_info.current_seq_id++; auto transient_bundle_meta = StreamingMessageBundleMeta::FromBytes(buffer_ptr->GetTransientBuffer()); bool is_barrier_bundle = transient_bundle_meta->IsBarrier(); @@ -313,23 +267,9 @@ bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, std::list message_list; uint32_t bundle_buffer_size = 0; const uint32_t max_queue_item_size = channel_info.queue_size; - - bool is_barrier = false; - - // Pop until one of the following condition meets: - // 1. ring buffer is empty - // 2. message count in bundle is larger than ring buffer size - // 3. sum of data size of messages in bundle is larger than streaming queue size - // 4. message type changed while (message_list.size() < runtime_context_->GetConfig().GetRingBufferCapacity() && !buffer_ptr->IsEmpty()) { StreamingMessagePtr &message_ptr = buffer_ptr->Front(); - STREAMING_LOG(DEBUG) << "Collecting message " << *message_ptr - << ", message_list_size=" << message_list.size() - << ", buffer capacity=" - << runtime_context_->GetConfig().GetRingBufferCapacity() - << ", buffer size=" << buffer_ptr->Size(); - uint32_t message_total_size = message_ptr->ClassBytesSize(); if (!message_list.empty() && bundle_buffer_size + message_total_size >= max_queue_item_size) { @@ -339,11 +279,6 @@ bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, } if (!message_list.empty() && message_list.back()->GetMessageType() != message_ptr->GetMessageType()) { - STREAMING_LOG(DEBUG) << "Different message type detected, break collecting, last " - "message type in list=" - << static_cast(message_list.back()->GetMessageType()) - << ", current collecing message type=" - << static_cast(message_ptr->GetMessageType()); break; } // ClassBytesSize = DataSize + MetaDataSize @@ -352,12 +287,6 @@ bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, message_list.push_back(message_ptr); buffer_ptr->Pop(); buffer_remain = buffer_ptr->Size(); - is_barrier = message_ptr->IsBarrier(); - STREAMING_LOG(DEBUG) << "Message " << *message_ptr - << " collected, message_list_size=" << message_list.size() - << ", buffer capacity=" - << runtime_context_->GetConfig().GetRingBufferCapacity() - << ", buffer size=" << buffer_ptr->Size(); } if (bundle_buffer_size >= channel_info.queue_size) { @@ -367,16 +296,9 @@ bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, } StreamingMessageBundlePtr bundle_ptr; - StreamingMessageBundleType bundleType = StreamingMessageBundleType::Bundle; - if (is_barrier) { - bundleType = StreamingMessageBundleType::Barrier; - } bundle_ptr = std::make_shared( - std::move(message_list), current_time_ms(), message_list.back()->GetMessageId(), - bundleType, bundle_buffer_size); - - STREAMING_LOG(DEBUG) << "CollectFromRingBuffer done, bundle=" << *bundle_ptr; - + std::move(message_list), current_time_ms(), message_list.back()->GetMessageSeqId(), + StreamingMessageBundleType::Bundle, bundle_buffer_size); buffer_ptr->ReallocTransientBuffer(bundle_ptr->ClassBytesSize()); bundle_ptr->ToBytes(buffer_ptr->GetTransientBufferMutable()); @@ -505,14 +427,14 @@ void DataWriter::RefreshChannelAndNotifyConsumed(ProducerChannelInfo &channel_in // Refresh current downstream consumed seq id. channel_map_[channel_info.channel_id]->RefreshChannelInfo(); // Notify the consumed information to local channel. - NotifyConsumedItem(channel_info, channel_info.queue_info.consumed_message_id); + NotifyConsumedItem(channel_info, channel_info.queue_info.consumed_seq_id); } void DataWriter::NotifyConsumedItem(ProducerChannelInfo &channel_info, uint32_t offset) { - if (offset > channel_info.current_message_id) { + if (offset > channel_info.current_seq_id) { STREAMING_LOG(WARNING) << "Can not notify consumed this offset " << offset << " that's out of range, max seq id " - << channel_info.current_message_id; + << channel_info.current_seq_id; } else { channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset); } @@ -550,58 +472,5 @@ void DataWriter::GetOffsetInfo( offset_map = &channel_info_map_; } -void DataWriter::ClearCheckpoint(uint64_t barrier_id) { - if (!barrier_helper_.Contains(barrier_id)) { - STREAMING_LOG(WARNING) << "no such barrier id => " << barrier_id; - return; - } - - std::string global_barrier_id_list_str = "|"; - - for (auto &queue_id : output_queue_ids_) { - uint64_t q_global_barrier_msg_id = 0; - StreamingStatus status = barrier_helper_.GetMsgIdByBarrierId(queue_id, barrier_id, - q_global_barrier_msg_id); - ProducerChannelInfo &channel_info = channel_info_map_[queue_id]; - if (status == StreamingStatus::OK) { - ClearCheckpointId(channel_info, q_global_barrier_msg_id); - } else { - STREAMING_LOG(WARNING) << "no seq record in q => " << queue_id << ", barrier id => " - << barrier_id; - } - global_barrier_id_list_str += - queue_id.Hex() + " : " + std::to_string(q_global_barrier_msg_id) + "| "; - reliability_helper_->CleanupCheckpoint(channel_info, barrier_id); - } - - STREAMING_LOG(INFO) - << "[Writer] [Barrier] [clear] global barrier flag, global barrier id => " - << barrier_id << ", seq id map => " << global_barrier_id_list_str; - - barrier_helper_.ReleaseBarrierMapById(barrier_id); - barrier_helper_.ReleaseBarrierMapCheckpointByBarrierId(barrier_id); -} - -void DataWriter::ClearCheckpointId(ProducerChannelInfo &channel_info, uint64_t msg_id) { - AutoSpinLock lock(notify_flag_); - - uint64_t current_msg_id = channel_info.current_message_id; - if (msg_id > current_msg_id) { - STREAMING_LOG(WARNING) << "current_msg_id=" << current_msg_id - << ", msg_id to be cleared=" << msg_id - << ", channel id = " << channel_info.channel_id; - } - channel_map_[channel_info.channel_id]->NotifyChannelConsumed(msg_id); - - STREAMING_LOG(DEBUG) << "clearing data from msg_id=" << msg_id - << ", qid= " << channel_info.channel_id; -} - -void DataWriter::GetChannelOffset(std::vector &result) { - for (auto &q_id : output_queue_ids_) { - result.push_back(channel_info_map_[q_id].current_message_id); - } -} - } // namespace streaming } // namespace ray diff --git a/streaming/src/data_writer.h b/streaming/src/data_writer.h index e2a18c334..43e3f7189 100644 --- a/streaming/src/data_writer.h +++ b/streaming/src/data_writer.h @@ -6,18 +6,15 @@ #include #include -#include "channel/channel.h" +#include "channel.h" #include "config/streaming_config.h" #include "event_service.h" #include "flow_control.h" #include "message/message_bundle.h" -#include "reliability/barrier_helper.h" -#include "reliability_helper.h" #include "runtime_context.h" namespace ray { namespace streaming { -class ReliabilityHelper; /// DataWriter is designed for data transporting between upstream and downstream. /// After the user sends the data, it does not immediately send the data to @@ -60,27 +57,6 @@ class DataWriter { const ObjectID &q_id, uint8_t *data, uint32_t data_size, StreamingMessageType message_type = StreamingMessageType::Message); - /// Send barrier to all channel. note there are user defined data in barrier bundle - /// \param barrier_id - /// \param data - /// \param data_size - /// - void BroadcastBarrier(uint64_t barrier_id, const uint8_t *data, uint32_t data_size); - - /// To relieve stress from large source/input data, we define a new function - /// clear_check_point - /// in producer/writer class. Worker can invoke this function if and only if - /// notify_consumed each item - /// flag is passed in reader/consumer, which means writer's producing became more - /// rhythmical and reader - /// can't walk on old way anymore. - /// \param barrier_id: user-defined numerical checkpoint id - void ClearCheckpoint(uint64_t barrier_id); - - /// Replay all queue from checkpoint, it's useful under FO - /// \param result offset vector - void GetChannelOffset(std::vector &result); - void Run(); void Stop(); @@ -136,8 +112,6 @@ class DataWriter { void FlowControlTimer(); - void ClearCheckpointId(ProducerChannelInfo &channel_info, uint64_t seq_id); - private: std::shared_ptr event_service_; @@ -150,15 +124,6 @@ class DataWriter { // unnecessary overflow. std::shared_ptr flow_controller_; - StreamingBarrierHelper barrier_helper_; - std::shared_ptr reliability_helper_; - - // Make thread-safe between loop thread and user thread. - // High-level runtime send notification about clear checkpoint if global - // checkpoint is finished and low-level will auto flush & evict item memory - // when no more space is available. - std::atomic_flag notify_flag_ = ATOMIC_FLAG_INIT; - protected: std::unordered_map channel_info_map_; /// ProducerChannel is middle broker for data transporting and all downstream diff --git a/streaming/src/event_service.cc b/streaming/src/event_service.cc index dc1003f9e..926a4039b 100644 --- a/streaming/src/event_service.cc +++ b/streaming/src/event_service.cc @@ -57,7 +57,6 @@ void EventQueue::Pop() { no_full_cv_.notify_all(); } -constexpr int EventQueue::kConditionTimeoutMs; void EventQueue::WaitFor(std::unique_lock &lock) { // To avoid deadlock when EventQueue is empty but is_active is changed in other // thread, Event queue should awaken this condtion variable and check it again. diff --git a/streaming/src/event_service.h b/streaming/src/event_service.h index adbba3111..298443800 100644 --- a/streaming/src/event_service.h +++ b/streaming/src/event_service.h @@ -6,8 +6,9 @@ #include #include -#include "channel/channel.h" -#include "ring_buffer/ring_buffer.h" +#include "channel.h" +#include "ray/core_worker/core_worker.h" +#include "ring_buffer.h" #include "util/streaming_util.h" namespace ray { diff --git a/streaming/src/flow_control.cc b/streaming/src/flow_control.cc index b49d10c81..77cde8128 100644 --- a/streaming/src/flow_control.cc +++ b/streaming/src/flow_control.cc @@ -10,21 +10,21 @@ UnconsumedSeqFlowControl::UnconsumedSeqFlowControl( bool UnconsumedSeqFlowControl::ShouldFlowControl(ProducerChannelInfo &channel_info) { auto &queue_info = channel_info.queue_info; - if (queue_info.target_message_id <= channel_info.current_message_id) { + if (queue_info.target_seq_id <= channel_info.current_seq_id) { channel_map_[channel_info.channel_id]->RefreshChannelInfo(); // Target seq id is maximum upper limit in current condition. - channel_info.queue_info.target_message_id = - channel_info.queue_info.consumed_message_id + consumed_step_; - STREAMING_LOG(DEBUG) - << "Flow control stop writing to downstream, current message id => " - << channel_info.current_message_id << ", target message id => " - << queue_info.target_message_id << ", consumed_id => " - << queue_info.consumed_message_id << ", q id => " << channel_info.channel_id - << ". if this log keeps printing, it means something wrong " - "with queue's info API, or downstream node is not " - "consuming data."; + channel_info.queue_info.target_seq_id = + channel_info.queue_info.consumed_seq_id + consumed_step_; + STREAMING_LOG(DEBUG) << "Flow control stop writing to downstream, current max id => " + << channel_info.current_seq_id << ", target seq id => " + << queue_info.target_seq_id << ", consumed_id => " + << queue_info.consumed_seq_id << ", q id => " + << channel_info.channel_id + << ". if this log keeps printing, it means something wrong " + "with queue's info API, or downstream node is not " + "consuming data."; // Double check after refreshing if target seq id is changed. - if (queue_info.target_message_id <= channel_info.current_message_id) { + if (queue_info.target_seq_id <= channel_info.current_seq_id) { return true; } } diff --git a/streaming/src/flow_control.h b/streaming/src/flow_control.h index 005e75b8d..0fdcb3291 100644 --- a/streaming/src/flow_control.h +++ b/streaming/src/flow_control.h @@ -1,6 +1,6 @@ #pragma once -#include "channel/channel.h" +#include "channel.h" namespace ray { namespace streaming { diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.cc similarity index 64% rename from streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc rename to streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.cc index 241db2b45..891d95e81 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.cc @@ -1,18 +1,17 @@ -#include "io_ray_streaming_runtime_transfer_channel_ChannelId.h" +#include "io_ray_streaming_runtime_transfer_ChannelId.h" + #include "streaming_jni_common.h" using namespace ray::streaming; -JNIEXPORT jlong JNICALL -Java_io_ray_streaming_runtime_transfer_channel_ChannelId_createNativeId( +JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_ChannelId_createNativeId( JNIEnv *env, jclass cls, jlong qid_address) { auto id = ray::ObjectID::FromBinary( std::string(reinterpret_cast(qid_address), ray::ObjectID::Size())); return reinterpret_cast(new ray::ObjectID(id)); } -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_channel_ChannelId_destroyNativeId( +JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_ChannelId_destroyNativeId( JNIEnv *env, jclass cls, jlong native_id_ptr) { auto id = reinterpret_cast(native_id_ptr); STREAMING_CHECK(id != nullptr); diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.h new file mode 100644 index 000000000..839617026 --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.h @@ -0,0 +1,31 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class io_ray_streaming_runtime_transfer_ChannelId */ + +#ifndef _Included_io_ray_streaming_runtime_transfer_ChannelId +#define _Included_io_ray_streaming_runtime_transfer_ChannelId +#ifdef __cplusplus +extern "C" { +#endif +#undef io_ray_streaming_runtime_transfer_ChannelId_ID_LENGTH +#define io_ray_streaming_runtime_transfer_ChannelId_ID_LENGTH 20L +/* + * Class: io_ray_streaming_runtime_transfer_ChannelId + * Method: createNativeId + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_ChannelId_createNativeId(JNIEnv *, jclass, jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_ChannelId + * Method: destroyNativeId + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_ChannelId_destroyNativeId(JNIEnv *, jclass, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc index ba4a4d405..bcbe68d39 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc @@ -12,13 +12,15 @@ using namespace ray::streaming; JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( JNIEnv *env, jclass, jobject streaming_queue_initial_parameters, - jobjectArray input_channels, jlongArray msg_id_array, jlong timer_interval, - jobject creation_status, jbyteArray config_bytes, jboolean is_mock) { + jobjectArray input_channels, jlongArray seq_id_array, jlongArray msg_id_array, + jlong timer_interval, jboolean isRecreate, jbyteArray config_bytes, + jboolean is_mock) { STREAMING_LOG(INFO) << "[JNI]: create DataReader."; std::vector parameter_vec; ParseChannelInitParameters(env, streaming_queue_initial_parameters, parameter_vec); std::vector input_channels_ids = jarray_to_object_id_vec(env, input_channels); + std::vector seq_ids = LongVectorFromJLongArray(env, seq_id_array).data; std::vector msg_ids = LongVectorFromJLongArray(env, msg_id_array).data; auto ctx = std::make_shared(); @@ -30,24 +32,8 @@ Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( if (is_mock) { ctx->MarkMockTest(); } - - // init reader auto reader = new DataReader(ctx); - std::vector creation_status_vec; - reader->Init(input_channels_ids, parameter_vec, msg_ids, creation_status_vec, - timer_interval); - - // add creation status to Java's List - jclass array_list_cls = env->GetObjectClass(creation_status); - jclass integer_cls = env->FindClass("java/lang/Integer"); - jmethodID array_list_add = - env->GetMethodID(array_list_cls, "add", "(Ljava/lang/Object;)Z"); - for (auto &status : creation_status_vec) { - jmethodID integer_init = env->GetMethodID(integer_cls, "", "(I)V"); - jobject integer_obj = - env->NewObject(integer_cls, integer_init, static_cast(status)); - env->CallBooleanMethod(creation_status, array_list_add, integer_obj); - } + reader->Init(input_channels_ids, parameter_vec, seq_ids, msg_ids, timer_interval); STREAMING_LOG(INFO) << "create native DataReader succeed"; return reinterpret_cast(reader); } @@ -65,6 +51,8 @@ JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBund } else if (StreamingStatus::GetBundleTimeOut == status) { } else if (StreamingStatus::InitQueueFailed == status) { throwRuntimeException(env, "init channel failed"); + } else if (StreamingStatus::WaitQueueTimeOut == status) { + throwRuntimeException(env, "wait channel object timeout"); } if (StreamingStatus::OK != status) { @@ -100,34 +88,3 @@ Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env, jlong ptr) { delete reinterpret_cast(ptr); } - -JNIEXPORT jbyteArray JNICALL -Java_io_ray_streaming_runtime_transfer_DataReader_getOffsetsInfoNative(JNIEnv *env, - jobject thisObj, - jlong ptr) { - auto reader = reinterpret_cast(ptr); - std::unordered_map *offset_map = nullptr; - reader->GetOffsetInfo(offset_map); - STREAMING_CHECK(offset_map); - // queue nums + (queue id + seq id + message id) * queue nums - int offset_data_size = - sizeof(uint32_t) + (kUniqueIDSize + sizeof(uint64_t) * 2) * offset_map->size(); - jbyteArray offsets_info = env->NewByteArray(offset_data_size); - int offset = 0; - // total queue nums - auto queue_nums = static_cast(offset_map->size()); - env->SetByteArrayRegion(offsets_info, offset, sizeof(uint32_t), - reinterpret_cast(&queue_nums)); - offset += sizeof(uint32_t); - // queue name & offset - for (auto &p : *offset_map) { - env->SetByteArrayRegion(offsets_info, offset, kUniqueIDSize, - reinterpret_cast(p.first.Data())); - offset += kUniqueIDSize; - // msg_id - env->SetByteArrayRegion(offsets_info, offset, sizeof(uint64_t), - reinterpret_cast(&p.second.current_message_id)); - offset += sizeof(uint64_t); - } - return offsets_info; -} \ No newline at end of file diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h index 43f677d34..221fe3105 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h @@ -1,17 +1,3 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - /* DO NOT EDIT THIS FILE - it is machine generated */ #include /* Header for class io_ray_streaming_runtime_transfer_DataReader */ @@ -25,12 +11,12 @@ extern "C" { * Class: io_ray_streaming_runtime_transfer_DataReader * Method: createDataReaderNative * Signature: - * (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[JJLjava/util/List;[BZ)J + * (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[J[JJZ[BZ)J */ JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( - JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlong, jobject, jbyteArray, - jboolean); + JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlongArray, jlong, jboolean, + jbyteArray, jboolean); /* * Class: io_ray_streaming_runtime_transfer_DataReader @@ -40,15 +26,6 @@ Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative( JNIEnv *, jobject, jlong, jlong, jlong, jlong); -/* - * Class: io_ray_streaming_runtime_transfer_DataReader - * Method: getOffsetsInfoNative - * Signature: (J)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_io_ray_streaming_runtime_transfer_DataReader_getOffsetsInfoNative(JNIEnv *, jobject, - jlong); - /* * Class: io_ray_streaming_runtime_transfer_DataReader * Method: stopReaderNative diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc index efccdf98e..4c2fa8422 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc @@ -79,40 +79,4 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env, jlong ptr) { auto *data_writer = reinterpret_cast(ptr); delete data_writer; -} - -JNIEXPORT jlongArray JNICALL -Java_io_ray_streaming_runtime_transfer_DataWriter_getOutputMsgIdNative(JNIEnv *env, - jobject thisObj, - jlong ptr) { - DataWriter *writer_client = reinterpret_cast(ptr); - - std::vector result; - writer_client->GetChannelOffset(result); - - jlongArray jArray = env->NewLongArray(result.size()); - jlong jdata[result.size()]; - for (size_t i = 0; i < result.size(); ++i) { - *(jdata + i) = result[i]; - } - env->SetLongArrayRegion(jArray, 0, result.size(), jdata); - return jArray; -} - -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_DataWriter_broadcastBarrierNative( - JNIEnv *env, jobject thisObj, jlong ptr, jlong checkpointId, jbyteArray data) { - STREAMING_LOG(INFO) << "jni: broadcast barrier, cp_id=" << checkpointId; - RawDataFromJByteArray raw_data(env, data); - DataWriter *writer_client = reinterpret_cast(ptr); - writer_client->BroadcastBarrier(checkpointId, raw_data.data, raw_data.data_size); -} - -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_DataWriter_clearCheckpointNative( - JNIEnv *env, jobject thisObj, jlong ptr, jlong checkpointId) { - STREAMING_LOG(INFO) << "[Producer] jni: clearCheckpoints."; - auto *writer = reinterpret_cast(ptr); - writer->ClearCheckpoint(checkpointId); - STREAMING_LOG(INFO) << "[Producer] clear checkpoint done."; -} +} \ No newline at end of file diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h index ff6ebb839..d445638b9 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h @@ -1,17 +1,3 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - /* DO NOT EDIT THIS FILE - it is machine generated */ #include /* Header for class io_ray_streaming_runtime_transfer_DataWriter */ @@ -58,35 +44,6 @@ JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *, jobject, jlong); -/* - * Class: io_ray_streaming_runtime_transfer_DataWriter - * Method: getOutputMsgIdNative - * Signature: (J)[J - */ -JNIEXPORT jlongArray JNICALL -Java_io_ray_streaming_runtime_transfer_DataWriter_getOutputMsgIdNative(JNIEnv *, jobject, - jlong); - -/* - * Class: io_ray_streaming_runtime_transfer_DataWriter - * Method: broadcastBarrierNative - * Signature: (JJJ[B)V - */ -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_DataWriter_broadcastBarrierNative(JNIEnv *, - jobject, jlong, - jlong, - jbyteArray); - -/* - * Class: io_ray_streaming_runtime_transfer_DataWriter - * Method: clearCheckpointNative - * Signature: (JJ)V - */ -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_DataWriter_clearCheckpointNative(JNIEnv *, jobject, - jlong, jlong); - #ifdef __cplusplus } #endif diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h index 4e5c826f5..24517d7c4 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h @@ -1,17 +1,3 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - /* DO NOT EDIT THIS FILE - it is machine generated */ #include /* Header for class io_ray_streaming_runtime_transfer_TransferHandler */ @@ -24,7 +10,7 @@ extern "C" { /* * Class: io_ray_streaming_runtime_transfer_TransferHandler * Method: createWriterClientNative - * Signature: ()J + * Signature: (J)J */ JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *, @@ -33,7 +19,7 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative( /* * Class: io_ray_streaming_runtime_transfer_TransferHandler * Method: createReaderClientNative - * Signature: ()J + * Signature: (J)J */ JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *, diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h deleted file mode 100644 index ab1295afd..000000000 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class io_ray_streaming_runtime_transfer_channel_ChannelId */ - -#ifndef _Included_io_ray_streaming_runtime_transfer_channel_ChannelId -#define _Included_io_ray_streaming_runtime_transfer_channel_ChannelId -#ifdef __cplusplus -extern "C" { -#endif -#undef io_ray_streaming_runtime_transfer_channel_ChannelId_ID_LENGTH -#define io_ray_streaming_runtime_transfer_channel_ChannelId_ID_LENGTH 20L -/* - * Class: io_ray_streaming_runtime_transfer_channel_ChannelId - * Method: createNativeId - * Signature: (J)J - */ -JNIEXPORT jlong JNICALL -Java_io_ray_streaming_runtime_transfer_channel_ChannelId_createNativeId(JNIEnv *, jclass, - jlong); - -/* - * Class: io_ray_streaming_runtime_transfer_channel_ChannelId - * Method: destroyNativeId - * Signature: (J)V - */ -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_channel_ChannelId_destroyNativeId(JNIEnv *, jclass, - jlong); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/streaming/src/lib/java/streaming_jni_common.h b/streaming/src/lib/java/streaming_jni_common.h index 1b70dbfc7..996d58d21 100644 --- a/streaming/src/lib/java/streaming_jni_common.h +++ b/streaming/src/lib/java/streaming_jni_common.h @@ -4,7 +4,7 @@ #include -#include "channel/channel.h" +#include "channel.h" #include "ray/core_worker/common.h" #include "util/streaming_logging.h" diff --git a/streaming/src/message/message.cc b/streaming/src/message/message.cc index 3216e3b3a..cd44f76ee 100644 --- a/streaming/src/message/message.cc +++ b/streaming/src/message/message.cc @@ -10,32 +10,30 @@ namespace ray { namespace streaming { -StreamingMessage::StreamingMessage(std::shared_ptr &payload_data, - uint32_t payload_size, uint64_t msg_id, - StreamingMessageType message_type) - : payload_(payload_data), - payload_size_(payload_size), +StreamingMessage::StreamingMessage(std::shared_ptr &data, uint32_t data_size, + uint64_t seq_id, StreamingMessageType message_type) + : message_data_(data), + data_size_(data_size), message_type_(message_type), - message_id_(msg_id) {} + message_id_(seq_id) {} -StreamingMessage::StreamingMessage(std::shared_ptr &&payload_data, - uint32_t payload_size, uint64_t msg_id, - StreamingMessageType message_type) - : payload_(payload_data), - payload_size_(payload_size), +StreamingMessage::StreamingMessage(std::shared_ptr &&data, uint32_t data_size, + uint64_t seq_id, StreamingMessageType message_type) + : message_data_(data), + data_size_(data_size), message_type_(message_type), - message_id_(msg_id) {} + message_id_(seq_id) {} -StreamingMessage::StreamingMessage(const uint8_t *payload_data, uint32_t payload_size, - uint64_t msg_id, StreamingMessageType message_type) - : payload_size_(payload_size), message_type_(message_type), message_id_(msg_id) { - payload_.reset(new uint8_t[payload_size], std::default_delete()); - std::memcpy(payload_.get(), payload_data, payload_size); +StreamingMessage::StreamingMessage(const uint8_t *data, uint32_t data_size, + uint64_t seq_id, StreamingMessageType message_type) + : data_size_(data_size), message_type_(message_type), message_id_(seq_id) { + message_data_.reset(new uint8_t[data_size], std::default_delete()); + std::memcpy(message_data_.get(), data, data_size_); } StreamingMessage::StreamingMessage(const StreamingMessage &msg) { - payload_size_ = msg.payload_size_; - payload_ = msg.payload_; + data_size_ = msg.data_size_; + message_data_ = msg.message_data_; message_id_ = msg.message_id_; message_type_ = msg.message_type_; } @@ -46,8 +44,8 @@ StreamingMessagePtr StreamingMessage::FromBytes(const uint8_t *bytes, uint32_t data_size = *reinterpret_cast(bytes + byte_offset); byte_offset += sizeof(data_size); - uint64_t msg_id = *reinterpret_cast(bytes + byte_offset); - byte_offset += sizeof(msg_id); + uint64_t seq_id = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(seq_id); StreamingMessageType msg_type = *reinterpret_cast(bytes + byte_offset); @@ -56,14 +54,14 @@ StreamingMessagePtr StreamingMessage::FromBytes(const uint8_t *bytes, auto buf = new uint8_t[data_size]; std::memcpy(buf, bytes + byte_offset, data_size); auto data_ptr = std::shared_ptr(buf, std::default_delete()); - return std::make_shared(data_ptr, data_size, msg_id, msg_type); + return std::make_shared(data_ptr, data_size, seq_id, msg_type); } void StreamingMessage::ToBytes(uint8_t *serlizable_data) { uint32_t byte_offset = 0; - std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&payload_size_), - sizeof(payload_size_)); - byte_offset += sizeof(payload_size_); + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&data_size_), + sizeof(data_size_)); + byte_offset += sizeof(data_size_); std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&message_id_), sizeof(message_id_)); @@ -73,28 +71,19 @@ void StreamingMessage::ToBytes(uint8_t *serlizable_data) { sizeof(message_type_)); byte_offset += sizeof(message_type_); - std::memcpy(serlizable_data + byte_offset, reinterpret_cast(payload_.get()), - payload_size_); + std::memcpy(serlizable_data + byte_offset, + reinterpret_cast(message_data_.get()), data_size_); - byte_offset += payload_size_; + byte_offset += data_size_; STREAMING_CHECK(byte_offset == this->ClassBytesSize()); } bool StreamingMessage::operator==(const StreamingMessage &message) const { - return PayloadSize() == message.PayloadSize() && - GetMessageId() == message.GetMessageId() && + return GetDataSize() == message.GetDataSize() && + GetMessageSeqId() == message.GetMessageSeqId() && GetMessageType() == message.GetMessageType() && - !std::memcmp(Payload(), message.Payload(), PayloadSize()); -} - -std::ostream &operator<<(std::ostream &os, const StreamingMessage &message) { - os << "{" - << " message_type_: " << static_cast(message.GetMessageType()) - << " message_id_: " << message.GetMessageId() - << " payload_size_: " << message.payload_size_ - << " payload_: " << (void *)message.payload_.get() << "}"; - return os; + !std::memcmp(RawData(), message.RawData(), data_size_); } } // namespace streaming diff --git a/streaming/src/message/message.h b/streaming/src/message/message.h index 4aa7c35b4..66a78e9f4 100644 --- a/streaming/src/message/message.h +++ b/streaming/src/message/message.h @@ -1,6 +1,5 @@ #pragma once -#include #include namespace ray { @@ -17,75 +16,52 @@ enum class StreamingMessageType : uint32_t { MAX = Message }; -enum class StreamingBarrierType : uint32_t { GlobalBarrier = 0 }; - -struct StreamingBarrierHeader { - StreamingBarrierType barrier_type; - uint64_t barrier_id; - inline bool IsGlobalBarrier() { - return StreamingBarrierType::GlobalBarrier == barrier_type; - } -}; - constexpr uint32_t kMessageHeaderSize = sizeof(uint32_t) + sizeof(uint64_t) + sizeof(StreamingMessageType); -constexpr uint32_t kBarrierHeaderSize = sizeof(StreamingBarrierType) + sizeof(uint64_t); - /// All messages should be wrapped by this protocol. // DataSize means length of raw data, message id is increasing from [1, +INF]. // MessageType will be used for barrier transporting and checkpoint. /// +----------------+ -/// | PayloadSize=U32| +/// | DataSize=U32 | /// +----------------+ /// | MessageId=U64 | /// +----------------+ /// | MessageType=U32| /// +----------------+ -/// | Payload=var | +/// | Data=var | /// +----------------+ -/// Payload field contains barrier header and carried buffer if message type is -/// global/partial barrier. -/// -/// Barrier's Payload field: -/// +----------------------------+ -/// | StreamingBarrierType=U32 | -/// +----------------------------+ -/// | barrier_id=U64 | -/// +----------------------------+ -/// | carried_buffer=var | -/// +----------------------------+ class StreamingMessage { private: - std::shared_ptr payload_; - uint32_t payload_size_; + std::shared_ptr message_data_; + uint32_t data_size_; StreamingMessageType message_type_; uint64_t message_id_; public: /// Copy raw data from outside shared buffer. - /// \param payload_ raw data from user buffer - /// \param payload_size_ raw data size - /// \param msg_id message id + /// \param data raw data from user buffer + /// \param data_size raw data size + /// \param seq_id message id /// \param message_type - StreamingMessage(std::shared_ptr &payload_data, uint32_t payload_size, - uint64_t msg_id, StreamingMessageType message_type); + StreamingMessage(std::shared_ptr &data, uint32_t data_size, uint64_t seq_id, + StreamingMessageType message_type); /// Move outsite raw data to message data. - /// \param payload_ raw data from user buffer - /// \param payload_size_ raw data size - /// \param msg_id message id + /// \param data raw data from user buffer + /// \param data_size raw data size + /// \param seq_id message id /// \param message_type - StreamingMessage(std::shared_ptr &&payload_data, uint32_t payload_size, - uint64_t msg_id, StreamingMessageType message_type); + StreamingMessage(std::shared_ptr &&data, uint32_t data_size, uint64_t seq_id, + StreamingMessageType message_type); /// Copy raw data from outside buffer. - /// \param payload_ raw data from user buffer - /// \param payload_size_ raw data size - /// \param msg_id message id + /// \param data raw data from user buffer + /// \param data_size raw data size + /// \param seq_id message id /// \param message_type - StreamingMessage(const uint8_t *payload_data, uint32_t payload_size, uint64_t msg_id, + StreamingMessage(const uint8_t *data, uint32_t data_size, uint64_t seq_id, StreamingMessageType message_type); StreamingMessage(const StreamingMessage &); @@ -94,44 +70,20 @@ class StreamingMessage { virtual ~StreamingMessage() = default; + inline uint8_t *RawData() const { return message_data_.get(); } + + inline uint32_t GetDataSize() const { return data_size_; } inline StreamingMessageType GetMessageType() const { return message_type_; } - inline uint64_t GetMessageId() const { return message_id_; } - - inline uint8_t *Payload() const { return payload_.get(); } - - inline uint32_t PayloadSize() const { return payload_size_; } - + inline uint64_t GetMessageSeqId() const { return message_id_; } inline bool IsMessage() { return StreamingMessageType::Message == message_type_; } inline bool IsBarrier() { return StreamingMessageType::Barrier == message_type_; } bool operator==(const StreamingMessage &) const; - static inline std::shared_ptr MakeBarrierPayload( - StreamingBarrierHeader &barrier_header, const uint8_t *data, uint32_t data_size) { - std::shared_ptr ptr(new uint8_t[data_size + kBarrierHeaderSize], - std::default_delete()); - std::memcpy(ptr.get(), &barrier_header.barrier_type, sizeof(StreamingBarrierType)); - std::memcpy(ptr.get() + sizeof(StreamingBarrierType), &barrier_header.barrier_id, - sizeof(uint64_t)); - if (data && data_size > 0) { - std::memcpy(ptr.get() + kBarrierHeaderSize, data, data_size); - } - return ptr; - } - virtual void ToBytes(uint8_t *data); static StreamingMessagePtr FromBytes(const uint8_t *data, bool verifer_check = true); - inline virtual uint32_t ClassBytesSize() { return kMessageHeaderSize + payload_size_; } - - static inline void GetBarrierIdFromRawData(const uint8_t *data, - StreamingBarrierHeader *barrier_header) { - barrier_header->barrier_type = *reinterpret_cast(data); - barrier_header->barrier_id = - *reinterpret_cast(data + sizeof(StreamingBarrierType)); - } - - friend std::ostream &operator<<(std::ostream &os, const StreamingMessage &message); + inline virtual uint32_t ClassBytesSize() { return kMessageHeaderSize + data_size_; } }; } // namespace streaming diff --git a/streaming/src/message/message_bundle.cc b/streaming/src/message/message_bundle.cc index 629a0613d..13057c428 100644 --- a/streaming/src/message/message_bundle.cc +++ b/streaming/src/message/message_bundle.cc @@ -63,14 +63,6 @@ bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta *meta) co return operator==(*meta); } -std::ostream &operator<<(std::ostream &os, const StreamingMessageBundleMeta &meta) { - os << "{" - << "last_message_id_: " << meta.last_message_id_ - << ", message_list_size_: " << meta.message_list_size_ - << ", bundle_type_: " << static_cast(meta.bundle_type_) << "}"; - return os; -} - StreamingMessageBundleMeta::StreamingMessageBundleMeta() : bundle_type_(StreamingMessageBundleType::Empty) {} @@ -196,13 +188,5 @@ bool StreamingMessageBundle::operator==(StreamingMessageBundle &bundle) const { bool StreamingMessageBundle::operator==(StreamingMessageBundle *bundle) const { return this->operator==(*bundle); } - -std::ostream &operator<<(std::ostream &os, const DataBundle &bundle) { - os << "{" - << "data: " << (void *)bundle.data << ", data_size: " << bundle.data_size - << ", channel last_barrier_id: " << bundle.last_barrier_id - << ", meta: " << *(bundle.meta) << "}"; - return os; -} } // namespace streaming } // namespace ray diff --git a/streaming/src/message/message_bundle.h b/streaming/src/message/message_bundle.h index 2cad05e36..a5f8687ca 100644 --- a/streaming/src/message/message_bundle.h +++ b/streaming/src/message/message_bundle.h @@ -7,7 +7,6 @@ #include #include "message/message.h" -#include "ray/common/id.h" namespace ray { namespace streaming { @@ -84,7 +83,6 @@ class StreamingMessageBundleMeta { inline bool IsBarrier() { return StreamingMessageBundleType::Barrier == bundle_type_; } inline bool IsBundle() { return StreamingMessageBundleType::Bundle == bundle_type_; } - inline bool IsEmptyMsg() { return StreamingMessageBundleType::Empty == bundle_type_; } virtual void ToBytes(uint8_t *data); static StreamingMessageBundleMetaPtr FromBytes(const uint8_t *data, @@ -101,9 +99,6 @@ class StreamingMessageBundleMeta { "," + std::to_string(message_bundle_ts_) + "," + std::to_string(static_cast(bundle_type_)); } - - friend std::ostream &operator<<(std::ostream &os, - const StreamingMessageBundleMeta &meta); }; /// StreamingMessageBundle inherits from metadata class (StreamingMessageBundleMeta) @@ -182,30 +177,5 @@ class StreamingMessageBundle : public StreamingMessageBundleMeta { const std::list &message_list, uint32_t raw_data_size, uint8_t *raw_data); }; - -/// Databundle is super-bundle that contains channel information (upstream -/// channel id & bundle meta data) and raw buffer pointer. -struct DataBundle { - uint8_t *data = nullptr; - uint32_t data_size; - ObjectID from; - uint32_t last_barrier_id; - StreamingMessageBundleMetaPtr meta; - bool is_reallocated = false; - - ~DataBundle() { - if (is_reallocated) { - delete[] data; - } - } - - void Realloc(uint32_t size) { - data = new uint8_t[size]; - is_reallocated = true; - } - - friend std::ostream &operator<<(std::ostream &os, const DataBundle &bundle); -}; - } // namespace streaming } // namespace ray diff --git a/streaming/src/protobuf/remote_call.proto b/streaming/src/protobuf/remote_call.proto index 34c2dac7b..5e9e2a754 100644 --- a/streaming/src/protobuf/remote_call.proto +++ b/streaming/src/protobuf/remote_call.proto @@ -4,8 +4,6 @@ package ray.streaming.proto; import "protobuf/streaming.proto"; -import "google/protobuf/any.proto"; - option java_package = "io.ray.streaming.runtime.generated"; // Execution vertex info, including it's upstream and downstream @@ -24,7 +22,7 @@ message ExecutionVertexContext { // unique id of execution vertex int32 execution_vertex_id = 1; // unique id of execution job vertex - int32 execution_job_vertex_id = 2; + int32 execution_job_vertex_Id = 2; // name of execution job vertex, e.g. 1-SourceOperator string execution_job_vertex_name = 3; // index of execution vertex @@ -58,48 +56,3 @@ message PythonJobWorkerContext { // vertex including it's upstream and downstream ExecutionVertexContext execution_vertex_context = 2; } - -message BoolResult { - bool boolRes = 1; -} - -message Barrier { - int64 id = 1; -} - -message CheckpointId { - int64 checkpoint_id = 1; -} - -message BaseWorkerCmd { - bytes actor_id = 1; // actor id - int64 timestamp = 2; - google.protobuf.Any detail = 3; -} - -message WorkerCommitReport { - int64 commit_checkpoint_id = 1; -} - -message WorkerRollbackRequest { - string exception_msg = 1; - string worker_hostname = 2; - string worker_pid = 3; -} - -message CallResult { - bool success = 1; - int32 result_code = 2; - string result_msg = 3; - QueueRecoverInfo result_obj = 4; -} - -message QueueRecoverInfo { - enum QueueCreationStatus { - FreshStarted = 0; - PullOk = 1; - Timeout = 2; - DataLost = 3; - } - map creation_status = 3; -} \ No newline at end of file diff --git a/streaming/src/protobuf/streaming.proto b/streaming/src/protobuf/streaming.proto index 9f26c20ef..e79143556 100644 --- a/streaming/src/protobuf/streaming.proto +++ b/streaming/src/protobuf/streaming.proto @@ -2,8 +2,6 @@ syntax = "proto3"; package ray.streaming.proto; -import "google/protobuf/any.proto"; - option java_package = "io.ray.streaming.runtime.generated"; enum Language { @@ -22,12 +20,6 @@ enum NodeType { SINK = 3; } -enum ReliabilityLevel { - NONE = 0; - AT_LEAST_ONCE = 1; - EXACTLY_ONCE = 2; -} - enum FlowControlType { UNKNOWN_FLOW_CONTROL_TYPE = 0; UnconsumedSeqFlowControl = 1; diff --git a/streaming/src/queue/message.cc b/streaming/src/queue/message.cc index b0395a806..ff1edad2a 100644 --- a/streaming/src/queue/message.cc +++ b/streaming/src/queue/message.cc @@ -90,7 +90,7 @@ std::shared_ptr DataMessage::FromBytes(uint8_t *bytes) { void NotificationMessage::ToProtobuf(std::string *output) { queue::protobuf::StreamingQueueNotificationMsg msg; FillMessageCommon(msg.mutable_common()); - msg.set_seq_id(msg_id_); + msg.set_seq_id(seq_id_); msg.SerializeToString(output); } diff --git a/streaming/src/queue/message.h b/streaming/src/queue/message.h index 9438a4714..42b474af2 100644 --- a/streaming/src/queue/message.h +++ b/streaming/src/queue/message.h @@ -102,19 +102,19 @@ class DataMessage : public Message { class NotificationMessage : public Message { public: NotificationMessage(const ActorID &actor_id, const ActorID &peer_actor_id, - const ObjectID &queue_id, uint64_t msg_id) - : Message(actor_id, peer_actor_id, queue_id), msg_id_(msg_id) {} + const ObjectID &queue_id, uint64_t seq_id) + : Message(actor_id, peer_actor_id, queue_id), seq_id_(seq_id) {} virtual ~NotificationMessage() {} static std::shared_ptr FromBytes(uint8_t *bytes); virtual void ToProtobuf(std::string *output); - inline uint64_t MsgId() { return msg_id_; } + inline uint64_t SeqId() { return seq_id_; } inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } private: - uint64_t msg_id_; + uint64_t seq_id_; const queue::protobuf::StreamingQueueMessageType type_ = queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType; }; diff --git a/streaming/src/queue/queue.cc b/streaming/src/queue/queue.cc index 1885a460e..434420ca3 100644 --- a/streaming/src/queue/queue.cc +++ b/streaming/src/queue/queue.cc @@ -101,8 +101,9 @@ size_t Queue::PendingCount() { return begin->SeqId() - end->SeqId() + 1; } -Status WriterQueue::Push(uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, - uint64_t msg_id_start, uint64_t msg_id_end, bool raw) { +Status WriterQueue::Push(uint64_t seq_id, uint8_t *buffer, uint32_t buffer_size, + uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end, + bool raw) { if (IsPendingFull(buffer_size)) { return Status::OutOfMemory("Queue Push OutOfMemory"); } @@ -112,9 +113,9 @@ Status WriterQueue::Push(uint8_t *buffer, uint32_t buffer_size, uint64_t timesta std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - QueueItem item(seq_id_, buffer, buffer_size, timestamp, msg_id_start, msg_id_end, raw); + QueueItem item(seq_id, buffer, buffer_size, timestamp, msg_id_start, msg_id_end, raw); Queue::Push(item); - STREAMING_LOG(DEBUG) << "WriterQueue::Push seq_id: " << seq_id_; + STREAMING_LOG(DEBUG) << "WriterQueue::Push seq_id_: " << seq_id_; seq_id_++; return Status::OK(); } @@ -131,41 +132,33 @@ void WriterQueue::Send() { } Status WriterQueue::TryEvictItems() { + STREAMING_LOG(INFO) << "TryEvictItems"; QueueItem item = FrontProcessed(); - STREAMING_LOG(DEBUG) << "TryEvictItems queue_id: " << queue_id_ << " first_item: (" - << item.MsgIdStart() << "," << item.MsgIdEnd() << ")" - << " min_consumed_msg_id_: " << min_consumed_msg_id_ - << " eviction_limit_: " << eviction_limit_ - << " max_data_size_: " << max_data_size_ - << " data_size_sent_: " << data_size_sent_ - << " data_size_: " << data_size_; - - if (min_consumed_msg_id_ == QUEUE_INVALID_SEQ_ID || - min_consumed_msg_id_ < item.MsgIdEnd()) { + uint64_t first_seq_id = item.SeqId(); + STREAMING_LOG(INFO) << "TryEvictItems first_seq_id: " << first_seq_id + << " min_consumed_id_: " << min_consumed_id_ + << " eviction_limit_: " << eviction_limit_; + if (min_consumed_id_ == QUEUE_INVALID_SEQ_ID || first_seq_id > min_consumed_id_) { return Status::OutOfMemory("The queue is full and some reader doesn't consume"); } - if (eviction_limit_ == QUEUE_INVALID_SEQ_ID || eviction_limit_ < item.MsgIdEnd()) { + if (eviction_limit_ == QUEUE_INVALID_SEQ_ID || first_seq_id > eviction_limit_) { return Status::OutOfMemory("The queue is full and eviction limit block evict"); } - uint64_t evict_target_msg_id = std::min(min_consumed_msg_id_, eviction_limit_); + uint64_t evict_target_seq_id = std::min(min_consumed_id_, eviction_limit_); - int count = 0; - while (item.MsgIdEnd() <= evict_target_msg_id) { + while (item.SeqId() <= evict_target_seq_id) { PopProcessed(); - STREAMING_LOG(INFO) << "TryEvictItems directly " << item.MsgIdEnd(); + STREAMING_LOG(INFO) << "TryEvictItems directly " << item.SeqId(); item = FrontProcessed(); - count++; } - STREAMING_LOG(DEBUG) << count << " items evicted, current item: (" << item.MsgIdStart() - << "," << item.MsgIdEnd() << ")"; return Status::OK(); } void WriterQueue::OnNotify(std::shared_ptr notify_msg) { - STREAMING_LOG(INFO) << "OnNotify target msg_id: " << notify_msg->MsgId(); - min_consumed_msg_id_ = notify_msg->MsgId(); + STREAMING_LOG(INFO) << "OnNotify target seq_id: " << notify_msg->SeqId(); + min_consumed_id_ = notify_msg->SeqId(); } void WriterQueue::ResendItem(QueueItem &item, uint64_t first_seq_id, @@ -280,22 +273,22 @@ void WriterQueue::OnPull( }); } -void ReaderQueue::OnConsumed(uint64_t msg_id) { - STREAMING_LOG(INFO) << "OnConsumed: " << msg_id; +void ReaderQueue::OnConsumed(uint64_t seq_id) { + STREAMING_LOG(INFO) << "OnConsumed: " << seq_id; QueueItem item = FrontProcessed(); - while (item.MsgIdEnd() <= msg_id) { + while (item.SeqId() <= seq_id) { PopProcessed(); item = FrontProcessed(); } - Notify(msg_id); + Notify(seq_id); } -void ReaderQueue::Notify(uint64_t msg_id) { +void ReaderQueue::Notify(uint64_t seq_id) { std::vector task_args; - CreateNotifyTask(msg_id, task_args); + CreateNotifyTask(seq_id, task_args); // SubmitActorTask - NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, msg_id); + NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, seq_id); std::unique_ptr buffer = msg.ToBytes(); transport_->Send(std::move(buffer)); @@ -305,10 +298,7 @@ void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector &task_a void ReaderQueue::OnData(QueueItem &item) { last_recv_seq_id_ = item.SeqId(); - last_recv_msg_id_ = item.MsgIdEnd(); - STREAMING_LOG(DEBUG) << "ReaderQueue::OnData queue_id: " << queue_id_ - << " seq_id: " << last_recv_seq_id_ << " msg_id: (" - << item.MsgIdStart() << "," << item.MsgIdEnd() << ")"; + STREAMING_LOG(DEBUG) << "ReaderQueue::OnData seq_id: " << last_recv_seq_id_; Push(item); } diff --git a/streaming/src/queue/queue.h b/streaming/src/queue/queue.h index 758b90934..b15c3cfe8 100644 --- a/streaming/src/queue/queue.h +++ b/streaming/src/queue/queue.h @@ -94,10 +94,10 @@ class Queue { inline size_t Count() { return buffer_queue_.size(); } /// Return item count in pending state. - inline size_t PendingCount(); + size_t PendingCount(); /// Return item count in processed state. - inline size_t ProcessedCount(); + size_t ProcessedCount(); inline ActorID GetActorID() { return actor_id_; } inline ActorID GetPeerActorID() { return peer_actor_id_; } @@ -135,7 +135,7 @@ class WriterQueue : public Queue { peer_actor_id_(peer_actor_id), seq_id_(QUEUE_INITIAL_SEQ_ID), eviction_limit_(QUEUE_INVALID_SEQ_ID), - min_consumed_msg_id_(QUEUE_INVALID_SEQ_ID), + min_consumed_id_(QUEUE_INVALID_SEQ_ID), peer_last_msg_id_(0), peer_last_seq_id_(QUEUE_INVALID_SEQ_ID), transport_(transport), @@ -143,14 +143,12 @@ class WriterQueue : public Queue { is_upstream_first_pull_(true) {} /// Push a continuous buffer into queue, the buffer consists of some messages packed by - /// DataWriter. - /// \param data, the buffer address - /// \param data_size, buffer size - /// \param timestamp, the timestamp when the buffer pushed in - /// \param msg_id_start, the message id of the first message in the buffer - /// \param msg_id_end, the message id of the last message in the buffer - /// \param raw, whether this buffer is raw data, be True only in test - Status Push(uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, + /// DataWriter. \param data, the buffer address \param data_size, buffer size \param + /// timestamp, the timestamp when the buffer pushed in \param msg_id_start, the message + /// id of the first message in the buffer \param msg_id_end, the message id of the last + /// message in the buffer \param raw, whether this buffer is raw data, be True only in + /// test + Status Push(uint64_t seq_id, uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end, bool raw = false); /// Callback function, will be called when downstream queue notifies @@ -169,14 +167,16 @@ class WriterQueue : public Queue { void Send(); /// Called when user pushs item into queue. The count of items - /// can be evicted, determined by eviction_limit_ and min_consumed_msg_id_. + /// can be evicted, determined by eviction_limit_ and min_consumed_id_. Status TryEvictItems(); - void SetQueueEvictionLimit(uint64_t msg_id) { eviction_limit_ = msg_id; } + void SetQueueEvictionLimit(uint64_t eviction_limit) { + eviction_limit_ = eviction_limit; + } uint64_t EvictionLimit() { return eviction_limit_; } - uint64_t GetMinConsumedMsgID() { return min_consumed_msg_id_; } + uint64_t GetMinConsumedSeqID() { return min_consumed_id_; } void SetPeerLastIds(uint64_t msg_id, uint64_t seq_id) { peer_last_msg_id_ = msg_id; @@ -215,7 +215,7 @@ class WriterQueue : public Queue { ActorID peer_actor_id_; uint64_t seq_id_; uint64_t eviction_limit_; - uint64_t min_consumed_msg_id_; + uint64_t min_consumed_id_; uint64_t peer_last_msg_id_; uint64_t peer_last_seq_id_; std::shared_ptr transport_; @@ -238,8 +238,8 @@ class ReaderQueue : public Queue { transport), actor_id_(actor_id), peer_actor_id_(peer_actor_id), + min_consumed_id_(QUEUE_INVALID_SEQ_ID), last_recv_seq_id_(QUEUE_INVALID_SEQ_ID), - last_recv_msg_id_(QUEUE_INVALID_SEQ_ID), transport_(transport) {} /// Delete processed items whose seq id <= seq_id, @@ -252,8 +252,9 @@ class ReaderQueue : public Queue { /// NOTE: this callback function is called in queue thread. void OnResendData(std::shared_ptr msg); - inline uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; } - inline uint64_t GetLastRecvMsgId() { return last_recv_msg_id_; } + uint64_t GetMinConsumedSeqID() { return min_consumed_id_; } + + uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; } private: void Notify(uint64_t seq_id); @@ -262,8 +263,8 @@ class ReaderQueue : public Queue { private: ActorID actor_id_; ActorID peer_actor_id_; + uint64_t min_consumed_id_; uint64_t last_recv_seq_id_; - uint64_t last_recv_msg_id_; std::shared_ptr promise_for_pull_; std::shared_ptr transport_; }; diff --git a/streaming/src/queue/queue_handler.cc b/streaming/src/queue/queue_handler.cc index c6ef288be..40a6033d4 100644 --- a/streaming/src/queue/queue_handler.cc +++ b/streaming/src/queue/queue_handler.cc @@ -260,7 +260,7 @@ void UpstreamQueueMessageHandler::OnNotify( << queue::protobuf::StreamingQueueMessageType_Name( notify_msg->Type()) << ", maybe queue has been destroyed, ignore it." - << " msg id: " << notify_msg->MsgId(); + << " seq id: " << notify_msg->SeqId(); return; } queue->OnNotify(notify_msg); diff --git a/streaming/src/queue/queue_item.h b/streaming/src/queue/queue_item.h index b63e0eb74..f3954f346 100644 --- a/streaming/src/queue/queue_item.h +++ b/streaming/src/queue/queue_item.h @@ -24,7 +24,6 @@ const uint64_t QUEUE_INITIAL_SEQ_ID = 1; /// LocalMemoryBuffer shared_ptr, which will be sent out by Transport. class QueueItem { public: - QueueItem() = default; /// Construct a QueueItem object. /// \param[in] seq_id the sequential id assigned by DataWriter for a message bundle and /// QueueItem. diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc index 6bb378e20..cd30955fa 100644 --- a/streaming/src/queue/transport.cc +++ b/streaming/src/queue/transport.cc @@ -36,7 +36,7 @@ void Transport::SendInternal(std::shared_ptr buffer, } void Transport::Send(std::shared_ptr buffer) { - STREAMING_LOG(DEBUG) << "Transport::Send buffer size: " << buffer->Size(); + STREAMING_LOG(INFO) << "Transport::Send buffer size: " << buffer->Size(); std::vector return_ids; SendInternal(std::move(buffer), async_func_, TASK_OPTION_RETURN_NUM_0, return_ids); } diff --git a/streaming/src/reliability/barrier_helper.cc b/streaming/src/reliability/barrier_helper.cc deleted file mode 100644 index 14d66b790..000000000 --- a/streaming/src/reliability/barrier_helper.cc +++ /dev/null @@ -1,165 +0,0 @@ -#include "barrier_helper.h" - -#include - -#include "util/streaming_logging.h" -#include "util/streaming_util.h" - -namespace ray { -namespace streaming { -StreamingStatus StreamingBarrierHelper::GetMsgIdByBarrierId(const ObjectID &q_id, - uint64_t barrier_id, - uint64_t &msg_id) { - std::lock_guard lock(global_barrier_mutex_); - auto queue_map = global_barrier_map_.find(barrier_id); - if (queue_map == global_barrier_map_.end()) { - return StreamingStatus::NoSuchItem; - } - auto msg_id_map = queue_map->second.find(q_id); - if (msg_id_map == queue_map->second.end()) { - return StreamingStatus::QueueIdNotFound; - } - msg_id = msg_id_map->second; - return StreamingStatus::OK; -} - -void StreamingBarrierHelper::SetMsgIdByBarrierId(const ObjectID &q_id, - uint64_t barrier_id, uint64_t msg_id) { - std::lock_guard lock(global_barrier_mutex_); - global_barrier_map_[barrier_id][q_id] = msg_id; -} - -void StreamingBarrierHelper::ReleaseBarrierMapById(uint64_t barrier_id) { - std::lock_guard lock(global_barrier_mutex_); - global_barrier_map_.erase(barrier_id); -} - -void StreamingBarrierHelper::ReleaseAllBarrierMap() { - std::lock_guard lock(global_barrier_mutex_); - global_barrier_map_.clear(); -} - -void StreamingBarrierHelper::MapBarrierToCheckpoint(uint64_t barrier_id, - uint64_t checkpoint) { - std::lock_guard lock(barrier_map_checkpoint_mutex_); - barrier_checkpoint_map_[barrier_id] = checkpoint; -} - -StreamingStatus StreamingBarrierHelper::GetCheckpointIdByBarrierId( - uint64_t barrier_id, uint64_t &checkpoint_id) { - std::lock_guard lock(barrier_map_checkpoint_mutex_); - auto checkpoint_item = barrier_checkpoint_map_.find(barrier_id); - if (checkpoint_item == barrier_checkpoint_map_.end()) { - return StreamingStatus::NoSuchItem; - } - - checkpoint_id = checkpoint_item->second; - return StreamingStatus::OK; -} - -void StreamingBarrierHelper::ReleaseBarrierMapCheckpointByBarrierId( - const uint64_t barrier_id) { - std::lock_guard lock(barrier_map_checkpoint_mutex_); - auto it = barrier_checkpoint_map_.begin(); - while (it != barrier_checkpoint_map_.end()) { - if (it->first <= barrier_id) { - it = barrier_checkpoint_map_.erase(it); - } else { - it++; - } - } -} - -StreamingStatus StreamingBarrierHelper::GetBarrierIdByLastMessageId(const ObjectID &q_id, - uint64_t message_id, - uint64_t &barrier_id, - bool is_pop) { - std::lock_guard lock(message_id_map_barrier_mutex_); - auto message_item = global_reversed_barrier_map_.find(message_id); - if (message_item == global_reversed_barrier_map_.end()) { - return StreamingStatus::NoSuchItem; - } - - auto message_queue_item = message_item->second.find(q_id); - if (message_queue_item == message_item->second.end()) { - return StreamingStatus::QueueIdNotFound; - } - if (message_queue_item->second->empty()) { - STREAMING_LOG(WARNING) << "[Barrier] q id => " << q_id.Hex() << ", str num => " - << Util::Hexqid2str(q_id.Hex()) << ", message id " - << message_id; - return StreamingStatus::NoSuchItem; - } else { - barrier_id = message_queue_item->second->front(); - if (is_pop) { - message_queue_item->second->pop(); - } - } - return StreamingStatus::OK; -} - -void StreamingBarrierHelper::SetBarrierIdByLastMessageId(const ObjectID &q_id, - uint64_t message_id, - uint64_t barrier_id) { - std::lock_guard lock(message_id_map_barrier_mutex_); - - auto max_message_id_barrier = max_message_id_map_.find(q_id); - // remove finished barrier in different last message id - if (max_message_id_barrier != max_message_id_map_.end() && - max_message_id_barrier->second != message_id) { - if (global_reversed_barrier_map_.find(max_message_id_barrier->second) != - global_reversed_barrier_map_.end()) { - global_reversed_barrier_map_.erase(max_message_id_barrier->second); - } - } - - max_message_id_map_[q_id] = message_id; - auto message_item = global_reversed_barrier_map_.find(message_id); - if (message_item == global_reversed_barrier_map_.end()) { - BarrierIdQueue temp_queue = std::make_shared>(); - temp_queue->push(barrier_id); - global_reversed_barrier_map_[message_id][q_id] = temp_queue; - return; - } - auto message_queue_item = message_item->second.find(q_id); - if (message_queue_item != message_item->second.end()) { - message_queue_item->second->push(barrier_id); - } else { - BarrierIdQueue temp_queue = std::make_shared>(); - temp_queue->push(barrier_id); - global_reversed_barrier_map_[message_id][q_id] = temp_queue; - } -} - -void StreamingBarrierHelper::GetAllBarrier(std::vector &barrier_id_vec) { - std::transform( - global_barrier_map_.begin(), global_barrier_map_.end(), - std::back_inserter(barrier_id_vec), - [](std::unordered_map>::value_type - pair) { return pair.first; }); -} - -bool StreamingBarrierHelper::Contains(uint64_t barrier_id) { - return global_barrier_map_.find(barrier_id) != global_barrier_map_.end(); -} - -uint32_t StreamingBarrierHelper::GetBarrierMapSize() { - return global_barrier_map_.size(); -} - -void StreamingBarrierHelper::GetCurrentMaxCheckpointIdInQueue( - const ObjectID &q_id, uint64_t &checkpoint_id) const { - auto item = current_max_checkpoint_id_map_.find(q_id); - if (item != current_max_checkpoint_id_map_.end()) { - checkpoint_id = item->second; - } else { - checkpoint_id = 0; - } -} - -void StreamingBarrierHelper::SetCurrentMaxCheckpointIdInQueue( - const ObjectID &q_id, const uint64_t checkpoint_id) { - current_max_checkpoint_id_map_[q_id] = checkpoint_id; -} -} // namespace streaming -} // namespace ray diff --git a/streaming/src/reliability/barrier_helper.h b/streaming/src/reliability/barrier_helper.h deleted file mode 100644 index 1ceacc47a..000000000 --- a/streaming/src/reliability/barrier_helper.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once -#include -#include - -#include "common/status.h" -#include "ray/common/id.h" - -namespace ray { -namespace streaming { -class StreamingBarrierHelper { - using BarrierIdQueue = std::shared_ptr>; - - private: - // Global barrier map set (global barrier id -> (channel id -> msg id)) - std::unordered_map> - global_barrier_map_; - - // Message id map to barrier id of each queue(continuous barriers hold same last message - // id) - // message id -> (queue id -> list(barrier id)). - // Thread unsafe to assign value in user's thread but collect it in loopforward thread. - std::unordered_map> - global_reversed_barrier_map_; - - std::unordered_map barrier_checkpoint_map_; - - std::unordered_map max_message_id_map_; - - // We assume default max checkpoint is 0. - std::unordered_map current_max_checkpoint_id_map_; - - std::mutex message_id_map_barrier_mutex_; - - std::mutex global_barrier_mutex_; - - std::mutex barrier_map_checkpoint_mutex_; - - public: - StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, - uint64_t &msg_id); - void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t seq_id); - bool Contains(uint64_t barrier_id); - void ReleaseBarrierMapById(uint64_t barrier_id); - void ReleaseAllBarrierMap(); - void GetAllBarrier(std::vector &barrier_id_vec); - uint32_t GetBarrierMapSize(); - - void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint); - StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id, - uint64_t &checkpoint_id); - void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id); - - StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id, - uint64_t &barrier_id, bool is_pop = false); - void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id, - uint64_t barrier_id); - - void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id, - uint64_t &checkpoint_id) const; - - void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id, - const uint64_t checkpoint_id); -}; -} // namespace streaming -} // namespace ray diff --git a/streaming/src/reliability_helper.cc b/streaming/src/reliability_helper.cc deleted file mode 100644 index 9e4a083ab..000000000 --- a/streaming/src/reliability_helper.cc +++ /dev/null @@ -1,113 +0,0 @@ -#include "reliability_helper.h" - -#include -namespace ray { -namespace streaming { - -std::shared_ptr ReliabilityHelperFactory::CreateReliabilityHelper( - const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, - DataWriter *writer, DataReader *reader) { - if (config.IsExactlyOnce()) { - return std::make_shared(config, barrier_helper, writer, reader); - } else { - return std::make_shared(config, barrier_helper, writer, reader); - } -} - -ReliabilityHelper::ReliabilityHelper(const StreamingConfig &config, - StreamingBarrierHelper &barrier_helper, - DataWriter *writer, DataReader *reader) - : config_(config), - barrier_helper_(barrier_helper), - writer_(writer), - reader_(reader) {} - -void ReliabilityHelper::Reload() {} - -bool ReliabilityHelper::StoreBundleMeta(ProducerChannelInfo &channel_info, - StreamingMessageBundlePtr &bundle_ptr, - bool is_replay) { - return false; -} - -bool ReliabilityHelper::FilterMessage(ProducerChannelInfo &channel_info, - const uint8_t *data, - StreamingMessageType message_type, - uint64_t *write_message_id) { - bool is_filtered = false; - uint64_t &message_id = channel_info.current_message_id; - uint64_t last_msg_id = channel_info.message_last_commit_id; - - if (StreamingMessageType::Barrier == message_type) { - is_filtered = message_id < last_msg_id; - } else { - message_id++; - // Message last commit id is the last item in queue or restore from queue. - // It skip directly since message id is less or equal than current commit id. - is_filtered = message_id <= last_msg_id && !config_.IsAtLeastOnce(); - } - *write_message_id = message_id; - - return is_filtered; -} - -void ReliabilityHelper::CleanupCheckpoint(ProducerChannelInfo &channel_info, - uint64_t barrier_id) {} - -StreamingStatus ReliabilityHelper::InitChannelMerger(uint32_t timeout) { - return reader_->InitChannelMerger(timeout); -} - -StreamingStatus ReliabilityHelper::HandleNoValidItem(ConsumerChannelInfo &channel_info) { - STREAMING_LOG(DEBUG) << "[Reader] Queue " << channel_info.channel_id - << " get item timeout, resend notify " - << channel_info.current_message_id; - reader_->NotifyConsumedItem(channel_info, channel_info.current_message_id); - return StreamingStatus::OK; -} - -AtLeastOnceHelper::AtLeastOnceHelper(const StreamingConfig &config, - StreamingBarrierHelper &barrier_helper, - DataWriter *writer, DataReader *reader) - : ReliabilityHelper(config, barrier_helper, writer, reader) {} - -StreamingStatus AtLeastOnceHelper::InitChannelMerger(uint32_t timeout) { - // No merge in AT_LEAST_ONCE - return StreamingStatus::OK; -} - -StreamingStatus AtLeastOnceHelper::HandleNoValidItem(ConsumerChannelInfo &channel_info) { - if (current_sys_time_ms() - channel_info.resend_notify_timer > - StreamingConfig::RESEND_NOTIFY_MAX_INTERVAL) { - STREAMING_LOG(INFO) << "[Reader] Queue " << channel_info.channel_id - << " get item timeout, resend notify " - << channel_info.current_message_id; - reader_->NotifyConsumedItem(channel_info, channel_info.current_message_id); - channel_info.resend_notify_timer = current_sys_time_ms(); - } - return StreamingStatus::Invalid; -} - -ExactlyOnceHelper::ExactlyOnceHelper(const StreamingConfig &config, - StreamingBarrierHelper &barrier_helper, - DataWriter *writer, DataReader *reader) - : ReliabilityHelper(config, barrier_helper, writer, reader) {} - -bool ExactlyOnceHelper::FilterMessage(ProducerChannelInfo &channel_info, - const uint8_t *data, - StreamingMessageType message_type, - uint64_t *write_message_id) { - bool is_filtered = ReliabilityHelper::FilterMessage(channel_info, data, message_type, - write_message_id); - if (is_filtered && StreamingMessageType::Barrier == message_type && - StreamingRole::SOURCE == config_.GetStreamingRole()) { - *write_message_id = channel_info.message_last_commit_id; - // Do not skip source barrier when it's reconstructing from downstream. - is_filtered = false; - STREAMING_LOG(INFO) << "append barrier to buffer ring " << *write_message_id - << ", last commit id " << channel_info.message_last_commit_id; - } - return is_filtered; -} -} // namespace streaming -} // namespace ray diff --git a/streaming/src/reliability_helper.h b/streaming/src/reliability_helper.h deleted file mode 100644 index 56089a085..000000000 --- a/streaming/src/reliability_helper.h +++ /dev/null @@ -1,66 +0,0 @@ -#pragma once -#include "channel/channel.h" -#include "data_reader.h" -#include "data_writer.h" -#include "reliability/barrier_helper.h" -#include "util/config.h" - -namespace ray { -namespace streaming { - -class ReliabilityHelper; -class DataWriter; -class DataReader; - -class ReliabilityHelperFactory { - public: - static std::shared_ptr CreateReliabilityHelper( - const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, - DataWriter *writer, DataReader *reader); -}; - -class ReliabilityHelper { - public: - ReliabilityHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, - DataWriter *writer, DataReader *reader); - virtual ~ReliabilityHelper() = default; - // Only exactly same need override this function. - virtual void Reload(); - // Store bundle meta or skip in replay mode. - virtual bool StoreBundleMeta(ProducerChannelInfo &channel_info, - StreamingMessageBundlePtr &bundle_ptr, - bool is_replay = false); - virtual void CleanupCheckpoint(ProducerChannelInfo &channel_info, uint64_t barrier_id); - // Filter message by different failover strategies. - virtual bool FilterMessage(ProducerChannelInfo &channel_info, const uint8_t *data, - StreamingMessageType message_type, - uint64_t *write_message_id); - virtual StreamingStatus InitChannelMerger(uint32_t timeout); - virtual StreamingStatus HandleNoValidItem(ConsumerChannelInfo &channel_info); - - protected: - const StreamingConfig &config_; - StreamingBarrierHelper &barrier_helper_; - DataWriter *writer_; - DataReader *reader_; -}; - -class AtLeastOnceHelper : public ReliabilityHelper { - public: - AtLeastOnceHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, - DataWriter *writer, DataReader *reader); - StreamingStatus InitChannelMerger(uint32_t timeout) override; - StreamingStatus HandleNoValidItem(ConsumerChannelInfo &channel_info) override; -}; - -class ExactlyOnceHelper : public ReliabilityHelper { - public: - ExactlyOnceHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, - DataWriter *writer, DataReader *reader); - bool FilterMessage(ProducerChannelInfo &channel_info, const uint8_t *data, - StreamingMessageType message_type, - uint64_t *write_message_id) override; - virtual ~ExactlyOnceHelper() = default; -}; -} // namespace streaming -} // namespace ray diff --git a/streaming/src/ring_buffer/ring_buffer.cc b/streaming/src/ring_buffer.cc similarity index 100% rename from streaming/src/ring_buffer/ring_buffer.cc rename to streaming/src/ring_buffer.cc diff --git a/streaming/src/ring_buffer/ring_buffer.h b/streaming/src/ring_buffer.h similarity index 100% rename from streaming/src/ring_buffer/ring_buffer.h rename to streaming/src/ring_buffer.h diff --git a/streaming/src/runtime_context.h b/streaming/src/runtime_context.h index a86ebbcd1..4b6f49ab8 100644 --- a/streaming/src/runtime_context.h +++ b/streaming/src/runtime_context.h @@ -2,8 +2,8 @@ #include -#include "common/status.h" #include "config/streaming_config.h" +#include "status.h" namespace ray { namespace streaming { diff --git a/streaming/src/common/status.h b/streaming/src/status.h similarity index 97% rename from streaming/src/common/status.h rename to streaming/src/status.h index 63a1cbaee..dde6f386a 100644 --- a/streaming/src/common/status.h +++ b/streaming/src/status.h @@ -19,6 +19,7 @@ enum class StreamingStatus : uint32_t { GetBundleTimeOut = 9, SkipSendEmptyMessage = 10, Interrupted = 11, + WaitQueueTimeOut = 12, OutOfMemory = 13, Invalid = 14, UnknownError = 15, diff --git a/streaming/src/test/message_serialization_tests.cc b/streaming/src/test/message_serialization_tests.cc index 14dfc0232..b94064591 100644 --- a/streaming/src/test/message_serialization_tests.cc +++ b/streaming/src/test/message_serialization_tests.cc @@ -80,7 +80,7 @@ TEST(StreamingSerializationTest, streaming_message_barrier_bundle_serialization_ auto s_item = s_message_list.back(); EXPECT_TRUE(s_item->ClassBytesSize() == m_item->ClassBytesSize()); EXPECT_TRUE(s_item->GetMessageType() == m_item->GetMessageType()); - EXPECT_TRUE(s_item->GetMessageId() == m_item->GetMessageId()); + EXPECT_TRUE(s_item->GetMessageSeqId() == m_item->GetMessageSeqId()); EXPECT_TRUE(s_item->GetDataSize() == m_item->GetDataSize()); EXPECT_TRUE( std::memcmp(s_item->RawData(), m_item->RawData(), m_item->GetDataSize()) == 0); diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 4d3c652a8..9bbcad803 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -67,13 +67,27 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { } private: - void StreamingWriterExactlyOnceTest() { - StreamingConfig config; - StreamingWriterStrategyTest(config); + void TestWriteMessageToBufferRing(std::shared_ptr writer_client, + std::vector &q_list) { + // const uint8_t temp_data[] = {1, 2, 4, 5}; - STREAMING_LOG(INFO) - << "StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest"; - status_ = true; + uint32_t i = 1; + while (i <= MESSAGE_BOUND_SIZE) { + for (auto &q_id : q_list) { + uint64_t buffer_len = (i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE); + uint8_t *data = new uint8_t[buffer_len]; + for (uint32_t j = 0; j < buffer_len; ++j) { + data[j] = j % 128; + } + + writer_client->WriteMessageToBufferRing(q_id, data, buffer_len, + StreamingMessageType::Message); + } + ++i; + } + + // Wait a while + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); } void StreamingWriterStrategyTest(StreamingConfig &config) { @@ -97,7 +111,6 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { std::shared_ptr runtime_context(new RuntimeContext()); runtime_context->SetConfig(config); - // Create writer. std::shared_ptr streaming_writer_client(new DataWriter(runtime_context)); uint64_t queue_size = 10 * 1000 * 1000; std::vector channel_seq_id_vec(queue_ids_.size(), 0); @@ -106,35 +119,22 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { STREAMING_LOG(INFO) << "streaming_writer_client Init done"; streaming_writer_client->Run(); - - // Write some data. std::thread test_loop_thread( &StreamingQueueWriterTestSuite::TestWriteMessageToBufferRing, this, streaming_writer_client, std::ref(queue_ids_)); + // test_loop_thread.detach(); if (test_loop_thread.joinable()) { test_loop_thread.join(); } } - void TestWriteMessageToBufferRing(std::shared_ptr writer_client, - std::vector &q_list) { - uint32_t i = 1; - while (i <= MESSAGE_BOUND_SIZE) { - for (auto &q_id : q_list) { - uint64_t buffer_len = (i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE); - uint8_t *data = new uint8_t[buffer_len]; - for (uint32_t j = 0; j < buffer_len; ++j) { - data[j] = j % 128; - } + void StreamingWriterExactlyOnceTest() { + StreamingConfig config; + StreamingWriterStrategyTest(config); - writer_client->WriteMessageToBufferRing(q_id, data, buffer_len, - StreamingMessageType::Message); - } - ++i; - } - STREAMING_LOG(INFO) << "Write data done."; - // Wait a while. - std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + STREAMING_LOG(INFO) + << "StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest"; + status_ = true; } }; @@ -180,7 +180,7 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { for (auto &q_id : queue_id_vec) { reader_client->NotifyConsumedItem((*offset_map)[q_id], - (*offset_map)[q_id].current_message_id); + (*offset_map)[q_id].current_seq_id); } // writer_client->ClearCheckpoint(msg->last_barrier_id); @@ -201,7 +201,7 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { recevied_message_cnt += message_list.size(); for (auto &item : message_list) { - uint64_t i = item->GetMessageId(); + uint64_t i = item->GetMessageSeqId(); uint32_t buff_len = i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE; if (i > MESSAGE_BOUND_SIZE) break; @@ -270,7 +270,7 @@ class StreamingQueueUpStreamTestSuite : public StreamingQueueTestSuite { } void GetQueueTest() { - // Sleep 2s, queue shoulde not exist when reader pull. + // Sleep 2s, queue shoulde not exist when reader pull std::this_thread::sleep_for(std::chrono::milliseconds(2000)); auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); ObjectID &queue_id = queue_ids_[0]; @@ -297,7 +297,7 @@ class StreamingQueueUpStreamTestSuite : public StreamingQueueTestSuite { } void PullPeerAsyncTest() { - // Sleep 2s, queue should not exist when reader pull. + // Sleep 2s, queue should not exist when reader pull std::this_thread::sleep_for(std::chrono::milliseconds(2000)); auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); ObjectID &queue_id = queue_ids_[0]; @@ -323,8 +323,10 @@ class StreamingQueueUpStreamTestSuite : public StreamingQueueTestSuite { uint8_t data[100]; memset(data, msg_id, 100); STREAMING_LOG(INFO) << "Writer User Push item msg_id: " << msg_id; - ASSERT_TRUE( - queue->Push(data, 100, current_sys_time_ms(), msg_id, msg_id, true).ok()); + ASSERT_TRUE(queue + ->Push(msg_id /*seqid*/, data, 100, current_sys_time_ms(), msg_id, + msg_id, true) + .ok()); queue->Send(); } diff --git a/streaming/src/test/mock_transfer_tests.cc b/streaming/src/test/mock_transfer_tests.cc index 62085ac73..f6268ec81 100644 --- a/streaming/src/test/mock_transfer_tests.cc +++ b/streaming/src/test/mock_transfer_tests.cc @@ -10,7 +10,7 @@ TEST(StreamingMockTransfer, mock_produce_consume) { ObjectID channel_id = ObjectID::FromRandom(); ProducerChannelInfo producer_channel_info; producer_channel_info.channel_id = channel_id; - producer_channel_info.current_message_id = 0; + producer_channel_info.current_seq_id = 0; MockProducer producer(transfer_config, producer_channel_info); ConsumerChannelInfo consumer_channel_info; @@ -22,12 +22,15 @@ TEST(StreamingMockTransfer, mock_produce_consume) { producer.ProduceItemToChannel(data, 3); uint8_t *data_consumed; uint32_t data_size_consumed; - consumer.ConsumeItemFromChannel(data_consumed, data_size_consumed, -1); + uint64_t data_seq_id; + consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1); EXPECT_EQ(data_size_consumed, 3); + EXPECT_EQ(data_seq_id, 1); EXPECT_EQ(std::memcmp(data_consumed, data, 3), 0); consumer.NotifyChannelConsumed(1); - auto status = consumer.ConsumeItemFromChannel(data_consumed, data_size_consumed, -1); + auto status = + consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1); EXPECT_EQ(status, StreamingStatus::NoSuchItem); } @@ -49,9 +52,8 @@ class StreamingTransferTest : public ::testing::Test { std::vector channel_id_vec(queue_vec.size(), 0); std::vector queue_size_vec(queue_vec.size(), 10000); std::vector params(queue_vec.size()); - std::vector creation_status; writer->Init(queue_vec, params, channel_id_vec, queue_size_vec); - reader->Init(queue_vec, params, channel_id_vec, creation_status, -1); + reader->Init(queue_vec, params, channel_id_vec, queue_size_vec, -1); } void DestroyTransfer() { writer.reset(); @@ -150,21 +152,18 @@ TEST_F(StreamingTransferTest, flow_control_test) { reader->GetOffsetInfo(reader_offset_info); uint32_t writer_step = writer_runtime_context->GetConfig().GetWriterConsumedStep(); uint32_t reader_step = reader_runtime_context->GetConfig().GetReaderConsumedStep(); - uint64_t &writer_current_msg_id = + uint64_t &writer_current_seq_id = (*writer_offset_info)[queue_vec[0]].current_seq_id; + uint64_t &writer_current_message_id = (*writer_offset_info)[queue_vec[0]].current_message_id; - uint64_t &writer_last_commit_id = - (*writer_offset_info)[queue_vec[0]].message_last_commit_id; - uint64_t &writer_target_msg_id = - (*writer_offset_info)[queue_vec[0]].queue_info.target_message_id; - uint64_t &reader_target_msg_id = - (*reader_offset_info)[queue_vec[0]].queue_info.target_message_id; - do { + uint64_t &reader_target_seq_id = + (*reader_offset_info)[queue_vec[0]].queue_info.target_seq_id; + while (writer_current_seq_id < writer_step) { + STREAMING_LOG(INFO) << "Writer currrent seq id " << writer_current_seq_id + << " message " << writer_current_message_id << " consumer step " + << writer_step; std::this_thread::sleep_for( std::chrono::milliseconds(StreamingConfig::TIME_WAIT_UINT)); - STREAMING_LOG(INFO) << "Writer currrent msg id " << writer_current_msg_id - << ", writer target_msg_id=" << writer_target_msg_id - << ", consumer step " << writer_step; - } while (writer_current_msg_id < writer_step); + } std::list read_message_list; while (read_message_list.size() < num) { @@ -174,8 +173,8 @@ TEST_F(StreamingTransferTest, flow_control_test) { auto &message_list = bundle_ptr->GetMessageList(); std::copy(message_list.begin(), message_list.end(), std::back_inserter(read_message_list)); - ASSERT_GE(writer_step, writer_last_commit_id - msg->meta->GetLastMessageId()); - ASSERT_GE(msg->meta->GetLastMessageId() + reader_step, reader_target_msg_id); + ASSERT_GE(writer_step, writer_current_seq_id - msg->seq_id); + ASSERT_GE(msg->seq_id + reader_step, reader_target_seq_id); } int index = 0; for (auto &message : read_message_list) { diff --git a/streaming/src/test/run_streaming_queue_test.sh b/streaming/src/test/run_streaming_queue_test.sh index 752c95b2b..c53e6295f 100755 --- a/streaming/src/test/run_streaming_queue_test.sh +++ b/streaming/src/test/run_streaming_queue_test.sh @@ -44,22 +44,15 @@ if [ ! -d "$RAY_ROOT/python" ]; then exit 1 fi -REDIS_MODULE="$RAY_ROOT/bazel-bin/libray_redis_module.so" -REDIS_SERVER_EXEC="$RAY_ROOT/bazel-bin/external/com_github_antirez_redis/redis-server" -STORE_EXEC="$RAY_ROOT/bazel-bin/plasma_store_server" -REDIS_CLIENT_EXEC="$RAY_ROOT/bazel-bin/redis-cli" -RAYLET_EXEC="$RAY_ROOT/bazel-bin/raylet" -STREAMING_TEST_WORKER_EXEC="$RAY_ROOT/bazel-bin/streaming/streaming_test_worker" -GCS_SERVER_EXEC="$RAY_ROOT/bazel-bin/gcs_server" - -# clear env -pgrep "plasma|DefaultDriver|DefaultWorker|AppStarter|redis|http_server|job_agent" | xargs kill -9 &> /dev/null +REDIS_MODULE="./bazel-bin/libray_redis_module.so" +REDIS_SERVER_EXEC="./bazel-bin/external/com_github_antirez_redis/redis-server" +STORE_EXEC="./bazel-bin/plasma_store_server" +REDIS_CLIENT_EXEC="./bazel-bin/redis-cli" +RAYLET_EXEC="./bazel-bin/raylet" +STREAMING_TEST_WORKER_EXEC="./bazel-bin/streaming/streaming_test_worker" +GCS_SERVER_EXEC="./bazel-bin/gcs_server" +# Allow cleanup commands to fail. # Run tests. - -# to run specific test, add --gtest_filter, below is an example -#$RAY_ROOT/bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $STREAMING_TEST_WORKER_EXEC $GCS_SERVER_EXEC $REDIS_SERVER_EXEC $REDIS_MODULE $REDIS_CLIENT_EXEC --gtest_filter=StreamingTest/StreamingWriterTest.streaming_writer_exactly_once_test/0 - -# run all tests -"$RAY_ROOT"/bazel-bin/streaming/streaming_queue_tests "$STORE_EXEC" "$RAYLET_EXEC" "$RAYLET_PORT" "$STREAMING_TEST_WORKER_EXEC" "$GCS_SERVER_EXEC" "$REDIS_SERVER_EXEC" "$REDIS_MODULE" "$REDIS_CLIENT_EXEC" +./bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC "$RAYLET_PORT" $STREAMING_TEST_WORKER_EXEC $GCS_SERVER_EXEC $REDIS_SERVER_EXEC $REDIS_MODULE $REDIS_CLIENT_EXEC sleep 1s diff --git a/streaming/src/test/streaming_queue_tests.cc b/streaming/src/test/streaming_queue_tests.cc index f45e2a45c..c2c678315 100644 --- a/streaming/src/test/streaming_queue_tests.cc +++ b/streaming/src/test/streaming_queue_tests.cc @@ -66,6 +66,7 @@ INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingExactlySameTest, } // namespace ray int main(int argc, char **argv) { + // set_streaming_log_config("streaming_writer_test", StreamingLogLevel::INFO, 0); ::testing::InitGoogleTest(&argc, argv); RAY_CHECK(argc == 9); ray::TEST_STORE_EXEC_PATH = std::string(argv[1]); diff --git a/streaming/src/util/config.cc b/streaming/src/util/config.cc deleted file mode 100644 index b5c3e320a..000000000 --- a/streaming/src/util/config.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "config.h" -namespace ray { -namespace streaming { - -boost::any &Config::Get(ConfigEnum key) const { - auto item = config_map_.find(key); - STREAMING_CHECK(item != config_map_.end()); - return item->second; -} - -boost::any Config::Get(ConfigEnum key, boost::any default_value) const { - auto item = config_map_.find(key); - if (item == config_map_.end()) { - return default_value; - } - return item->second; -} - -} // namespace streaming -} // namespace ray diff --git a/streaming/src/util/config.h b/streaming/src/util/config.h deleted file mode 100644 index 56c6af81f..000000000 --- a/streaming/src/util/config.h +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once -#include -#include - -#include "streaming_logging.h" - -namespace ray { -namespace streaming { -enum class ConfigEnum : uint32_t { - QUEUE_ID_VECTOR = 0, - MIN = QUEUE_ID_VECTOR, - MAX = QUEUE_ID_VECTOR -}; -} -} // namespace ray - -namespace std { -template <> -struct hash<::ray::streaming::ConfigEnum> { - size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { - return static_cast(config_enum_key); - } -}; - -template <> -struct hash { - size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { - return static_cast(config_enum_key); - } -}; -} // namespace std - -namespace ray { -namespace streaming { - -class Config { - public: - template - inline void Set(ConfigEnum key, const ValueType &any) { - config_map_.emplace(key, any); - } - - template - inline void Set(ConfigEnum key, ValueType &&any) { - config_map_.emplace(key, any); - } - - template - inline boost::any &GetOrDefault(ConfigEnum key, ValueType &&any) { - auto item = config_map_.find(key); - if (item != config_map_.end()) { - return item->second; - } - Set(key, any); - return any; - } - - boost::any &Get(ConfigEnum key) const; - boost::any Get(ConfigEnum key, boost::any default_value) const; - - inline uint32_t GetInt32(ConfigEnum key) { return boost::any_cast(Get(key)); } - - inline uint64_t GetInt64(ConfigEnum key) { return boost::any_cast(Get(key)); } - - inline double GetDouble(ConfigEnum key) { return boost::any_cast(Get(key)); } - - inline bool GetBool(ConfigEnum key) { return boost::any_cast(Get(key)); } - - inline std::string GetString(ConfigEnum key) { - return boost::any_cast(Get(key)); - } - - virtual ~Config() = default; - - protected: - mutable std::unordered_map config_map_; -}; - -} // namespace streaming -} // namespace ray diff --git a/streaming/src/util/streaming_util.cc b/streaming/src/util/streaming_util.cc index 4f2a13535..95038f2a9 100644 --- a/streaming/src/util/streaming_util.cc +++ b/streaming/src/util/streaming_util.cc @@ -3,6 +3,21 @@ #include namespace ray { namespace streaming { + +boost::any &Config::Get(ConfigEnum key) const { + auto item = config_map_.find(key); + STREAMING_CHECK(item != config_map_.end()); + return item->second; +} + +boost::any Config::Get(ConfigEnum key, boost::any default_value) const { + auto item = config_map_.find(key); + if (item == config_map_.end()) { + return default_value; + } + return item->second; +} + std::string Util::Byte2hex(const uint8_t *data, uint32_t data_size) { constexpr char hex[] = "0123456789abcdef"; std::string result; diff --git a/streaming/src/util/streaming_util.h b/streaming/src/util/streaming_util.h index ecb5403bd..8f28dc3ec 100644 --- a/streaming/src/util/streaming_util.h +++ b/streaming/src/util/streaming_util.h @@ -4,80 +4,94 @@ #include #include -#include "ray/common/id.h" #include "util/streaming_logging.h" namespace ray { namespace streaming { +enum class ConfigEnum : uint32_t { + QUEUE_ID_VECTOR = 0, + RECONSTRUCT_RETRY_TIMES, + RECONSTRUCT_TIMEOUT_PER_MB, + CURRENT_DRIVER_ID, + /// For direct call + CORE_WORKER, + SYNC_FUNCTION, + ASYNC_FUNCTION, + TRANSFER_MIN = QUEUE_ID_VECTOR, + TRANSFER_MAX = ASYNC_FUNCTION +}; +} // namespace streaming +} // namespace ray + +namespace std { +template <> +struct hash<::ray::streaming::ConfigEnum> { + size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { + return static_cast(config_enum_key); + } +}; + +template <> +struct hash { + size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { + return static_cast(config_enum_key); + } +}; +} // namespace std + +namespace ray { +namespace streaming { + +class Config { + public: + template + inline void Set(ConfigEnum key, const ValueType &any) { + config_map_.emplace(key, any); + } + + template + inline void Set(ConfigEnum key, ValueType &&any) { + config_map_.emplace(key, any); + } + + template + inline boost::any &GetOrDefault(ConfigEnum key, ValueType &&any) { + auto item = config_map_.find(key); + if (item != config_map_.end()) { + return item->second; + } + Set(key, any); + return any; + } + + boost::any &Get(ConfigEnum key) const; + + boost::any Get(ConfigEnum key, boost::any default_value) const; + + inline uint32_t GetInt32(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline uint64_t GetInt64(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline double GetDouble(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline bool GetBool(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline std::string GetString(ConfigEnum key) { + return boost::any_cast(Get(key)); + } + + virtual ~Config() = default; + + protected: + mutable std::unordered_map config_map_; +}; + class Util { public: static std::string Byte2hex(const uint8_t *data, uint32_t data_size); static std::string Hexqid2str(const std::string &q_id_hex); - - template - static std::string join(const T &v, const std::string &delimiter, - const std::string &prefix = "", - const std::string &suffix = "") { - std::stringstream ss; - size_t i = 0; - ss << prefix; - for (const auto &elem : v) { - if (i != 0) { - ss << delimiter; - } - ss << elem; - i++; - } - ss << suffix; - return ss.str(); - } - - template - static std::string join(InputIterator first, InputIterator last, - const std::string &delim, const std::string &arround = "") { - std::string a = arround; - while (first != last) { - a += std::to_string(*first); - first++; - if (first != last) a += delim; - } - a += arround; - return a; - } - - template - static std::string join(InputIterator first, InputIterator last, - std::function func, - const std::string &delim, const std::string &arround = "") { - std::string a = arround; - while (first != last) { - a += func(first); - first++; - if (first != last) a += delim; - } - a += arround; - return a; - } }; - -class AutoSpinLock { - public: - explicit AutoSpinLock(std::atomic_flag &lock) : lock_(lock) { - while (lock_.test_and_set(std::memory_order_acquire)) - ; - } - ~AutoSpinLock() { unlock(); } - void unlock() { lock_.clear(std::memory_order_release); } - - private: - std::atomic_flag &lock_; -}; - -inline void ConvertToValidQueueId(const ObjectID &queue_id) { - auto addr = const_cast(&queue_id); - *(reinterpret_cast(addr)) = 0; -} } // namespace streaming } // namespace ray