diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel index 269e10433..6bb71ae98 100644 --- a/streaming/BUILD.bazel +++ b/streaming/BUILD.bazel @@ -9,6 +9,9 @@ proto_library( srcs = ["src/protobuf/streaming.proto"], strip_import_prefix = "src", visibility = ["//visibility:public"], + deps = [ + "@com_google_protobuf//:any_proto", + ], ) proto_library( @@ -22,7 +25,10 @@ proto_library( srcs = ["src/protobuf/remote_call.proto"], strip_import_prefix = "src", visibility = ["//visibility:public"], - deps = ["streaming_proto"], + deps = [ + "streaming_proto", + "@com_google_protobuf//:any_proto", + ], ) cc_proto_library( @@ -70,9 +76,10 @@ cc_library( "src/util/*.h", ]), copts = COPTS, - strip_include_prefix = "src", + includes = ["src"], visibility = ["//visibility:public"], deps = [ + "ray_common.so", "ray_util.so", "@boost//:any", "@com_google_googletest//:gtest", @@ -143,6 +150,62 @@ cc_library( }), ) +cc_library( + name = "streaming_channel", + srcs = glob(["src/channel/*.cc"]), + hdrs = glob(["src/channel/*.h"]), + copts = COPTS, + visibility = ["//visibility:public"], + deps = [ + ":streaming_common", + ":streaming_message", + ":streaming_queue", + ":streaming_ring_buffer", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_reliability", + srcs = glob(["src/reliability/*.cc"]), + hdrs = glob(["src/reliability/*.h"]), + copts = COPTS, + includes = ["src/"], + visibility = ["//visibility:public"], + deps = [ + ":streaming_channel", + ":streaming_message", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_ring_buffer", + srcs = glob(["src/ring_buffer/*.cc"]), + hdrs = glob(["src/ring_buffer/*.h"]), + copts = COPTS, + includes = ["src/"], + visibility = ["//visibility:public"], + deps = [ + "core_worker_lib.so", + ":ray_common.so", + ":ray_util.so", + ":streaming_message", + "@boost//:circular_buffer", + "@boost//:thread", + ], +) + +cc_library( + name = "streaming_common", + srcs = glob(["src/common/*.cc"]), + hdrs = 
glob(["src/common/*.h"]), + copts = COPTS, + includes = ["src/"], + visibility = ["//visibility:public"], + deps = [], +) + cc_library( name = "streaming_lib", srcs = glob([ @@ -159,11 +222,13 @@ cc_library( deps = [ "ray_common.so", "ray_util.so", + ":streaming_channel", + ":streaming_common", ":streaming_config", ":streaming_message", ":streaming_queue", + ":streaming_reliability", ":streaming_util", - "@boost//:circular_buffer", ], ) @@ -284,6 +349,7 @@ genrule( mkdir -p "$$GENERATED_DIR" touch "$$GENERATED_DIR/__init__.py" sed -i -E 's/from streaming.src.protobuf/from ./' "$$GENERATED_DIR/remote_call_pb2.py" + sed -i -E 's/from protobuf/from ./' "$$GENERATED_DIR/remote_call_pb2.py" date > $@ """, local = 1, @@ -298,7 +364,6 @@ cc_binary( ]), copts = COPTS, linkshared = 1, - linkstatic = 1, visibility = ["//visibility:public"], deps = [ ":streaming_lib", diff --git a/streaming/java/BUILD.bazel b/streaming/java/BUILD.bazel index 91de8130e..8ee94c597 100644 --- a/streaming/java/BUILD.bazel +++ b/streaming/java/BUILD.bazel @@ -127,10 +127,12 @@ define_java_module( ":io_ray_ray_streaming-state", "//java:io_ray_ray_api", "//java:io_ray_ray_runtime", + "@maven//:commons_io_commons_io", "@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java", "@ray_streaming_maven//:com_google_code_findbugs_jsr305", "@ray_streaming_maven//:com_google_guava_guava", "@ray_streaming_maven//:com_google_protobuf_protobuf_java", + "@ray_streaming_maven//:commons_collections_commons_collections", "@ray_streaming_maven//:de_ruedigermoeller_fst", "@ray_streaming_maven//:org_aeonbits_owner_owner", "@ray_streaming_maven//:org_apache_commons_commons_lang3", diff --git a/streaming/java/checkstyle-suppressions.xml b/streaming/java/checkstyle-suppressions.xml index cb07198ed..1d86cdba0 100644 --- a/streaming/java/checkstyle-suppressions.xml +++ b/streaming/java/checkstyle-suppressions.xml @@ -11,4 +11,7 @@ + + + diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl 
index 1fe083f99..b834f6a39 100644 --- a/streaming/java/dependencies.bzl +++ b/streaming/java/dependencies.bzl @@ -25,6 +25,7 @@ def gen_streaming_java_deps(): "org.mockito:mockito-all:1.10.19", "org.powermock:powermock-module-testng:1.6.6", "org.powermock:powermock-api-mockito:1.6.6", + "commons-collections:commons-collections:3.2.1", ], repositories = [ "https://repo.spring.io/plugins-release/", diff --git a/streaming/java/generate_jni_header_files.sh b/streaming/java/generate_jni_header_files.sh new file mode 100755 index 000000000..5ce3cb7a2 --- /dev/null +++ b/streaming/java/generate_jni_header_files.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +set -e +set -x + +cd "$(dirname "$0")" + +bazel build all_streaming_tests_deploy.jar + +function generate_one() +{ + file=${1//./_}.h + javah -classpath ../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar "$1" + + # prepend licence first + cat < ../src/lib/java/"$file" +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +EOF + # then append the generated header file + cat "$file" >> ../src/lib/java/"$file" + rm -f "$file" +} + +generate_one io.ray.streaming.runtime.transfer.channel.ChannelId +generate_one io.ray.streaming.runtime.transfer.DataReader +generate_one io.ray.streaming.runtime.transfer.DataWriter +generate_one io.ray.streaming.runtime.transfer.TransferHandler + +rm -f io_ray_streaming_*.h diff --git a/streaming/java/pom.xml b/streaming/java/pom.xml index 3432e006e..003c7670a 100644 --- a/streaming/java/pom.xml +++ b/streaming/java/pom.xml @@ -65,7 +65,6 @@ 27.0.1-jre 2.57 - release diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java index f29915381..82791a622 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/Function.java @@ -7,4 +7,26 @@ import java.io.Serializable; */ public interface Function extends Serializable { + /** + * This method will be called periodically by framework, you should return a serializable + * object which represents function state, framework will help you to serialize this object, save + * it to storage, and load it back during fail-over through + * {@link Function#loadCheckpoint(Serializable)}. + * + * @return A serializable object which represents function state. + */ + default Serializable saveCheckpoint() { + return null; + } + + /** + * This method will be called by framework when a worker has died and been restarted. + * We will pass the last object you returned in {@link Function#saveCheckpoint()} when + * doing checkpoint, you are responsible for loading this object back into your function. 
+ * + * @param checkpointObject the last object you returned in {@link Function#saveCheckpoint()} + */ + default void loadCheckpoint(Serializable checkpointObject) { + } + } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java index 40135b34b..96900d841 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/impl/SourceFunction.java @@ -9,9 +9,9 @@ import io.ray.streaming.api.function.Function; */ public interface SourceFunction extends Function { - void init(int parallel, int index); + void init(int parallelism, int index); - void run(SourceContext ctx) throws Exception; + void fetch(SourceContext ctx) throws Exception; void close(); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java index b14aa9a6c..ec63b7d7e 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/api/function/internal/CollectionSourceFunction.java @@ -1,7 +1,6 @@ package io.ray.streaming.api.function.internal; import io.ray.streaming.api.function.impl.SourceFunction; -import java.util.ArrayList; import java.util.Collection; /** @@ -12,22 +11,25 @@ import java.util.Collection; public class CollectionSourceFunction implements SourceFunction { private Collection values; + private boolean finished = false; public CollectionSourceFunction(Collection values) { this.values = values; } @Override - public void init(int parallel, int index) { + public void init(int totalParallel, int currentIndex) { } 
@Override - public void run(SourceContext ctx) throws Exception { + public void fetch(SourceContext ctx) throws Exception { + if (finished) { + return; + } for (T value : values) { ctx.collect(value); } - // empty collection - values = new ArrayList<>(); + finished = true; } @Override diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java index c99ec9959..4eb655689 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/message/KeyRecord.java @@ -1,6 +1,5 @@ package io.ray.streaming.message; - import java.util.Objects; public class KeyRecord extends Record { diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java index 0bbb0d7a2..d054b95a7 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/Operator.java @@ -25,4 +25,13 @@ public interface Operator extends Serializable { ChainStrategy getChainStrategy(); + /** + * See {@link Function#saveCheckpoint()}. + */ + Serializable saveCheckpoint(); + + /** + * See {@link Function#loadCheckpoint(Serializable)}. 
+ */ + void loadCheckpoint(Serializable checkpointObject); } diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java index 3cf9ab1d7..11f35f495 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/SourceOperator.java @@ -4,7 +4,7 @@ import io.ray.streaming.api.function.impl.SourceFunction.SourceContext; public interface SourceOperator extends Operator { - void run(); + void fetch(); SourceContext getSourceContext(); diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java index 67bc77381..fda6c5d0e 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/StreamOperator.java @@ -8,6 +8,7 @@ import io.ray.streaming.api.function.RichFunction; import io.ray.streaming.api.function.internal.Functions; import io.ray.streaming.message.KeyRecord; import io.ray.streaming.message.Record; +import java.io.Serializable; import java.util.List; public abstract class StreamOperator implements Operator { @@ -72,6 +73,16 @@ public abstract class StreamOperator implements Operator { } } + @Override + public Serializable saveCheckpoint() { + return function.saveCheckpoint(); + } + + @Override + public void loadCheckpoint(Serializable checkpointObject) { + function.loadCheckpoint(checkpointObject); + } + @Override public String getName() { return name; diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java index 
c7c9e7a18..3a4e32cbb 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java +++ b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/chain/ChainedOperator.java @@ -13,6 +13,7 @@ import io.ray.streaming.operator.OperatorType; import io.ray.streaming.operator.SourceOperator; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.operator.TwoInputOperator; +import java.io.Serializable; import java.lang.reflect.Proxy; import java.util.Collections; import java.util.List; @@ -85,6 +86,25 @@ public abstract class ChainedOperator extends StreamOperator { return tailOperator; } + @Override + public Serializable saveCheckpoint() { + // Allocate a Serializable[] (not Object[]): loadCheckpoint casts the restored value to + // Serializable[], and an Object[] instance fails that cast with ClassCastException. + Serializable[] checkpoints = new Serializable[operators.size()]; + for (int i = 0; i < operators.size(); ++i) { + checkpoints[i] = operators.get(i).saveCheckpoint(); + } + return checkpoints; + } + + @Override + public void loadCheckpoint(Serializable checkpointObject) { + Serializable[] checkpoints = (Serializable[]) checkpointObject; + for (int i = 0; i < operators.size(); ++i) { + operators.get(i).loadCheckpoint(checkpoints[i]); + } + } + private RuntimeContext createRuntimeContext(RuntimeContext runtimeContext, int index) { return (RuntimeContext) Proxy.newProxyInstance(runtimeContext.getClass().getClassLoader(), new Class[] {RuntimeContext.class}, @@ -125,8 +145,8 @@ public abstract class ChainedOperator extends StreamOperator { } @Override - public void run() { - sourceOperator.run(); + public void fetch() { + sourceOperator.fetch(); } @Override diff --git a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java index 495604c3a..120701d88 100644 --- a/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java +++ 
b/streaming/java/streaming-api/src/main/java/io/ray/streaming/operator/impl/SourceOperatorImpl.java @@ -29,9 +29,9 @@ public class SourceOperatorImpl extends StreamOperator> } @Override - public void run() { + public void fetch() { try { - this.function.run(this.sourceContext); + this.function.fetch(this.sourceContext); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/streaming/java/streaming-runtime/pom.xml b/streaming/java/streaming-runtime/pom.xml index 6f63fe147..86a11ec8d 100644 --- a/streaming/java/streaming-runtime/pom.xml +++ b/streaming/java/streaming-runtime/pom.xml @@ -1,8 +1,8 @@ - + + xmlns="http://maven.apache.org/POM/4.0.0" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> ray-streaming io.ray @@ -69,6 +69,16 @@ protobuf-java 3.8.0 + + commons-collections + commons-collections + 3.2.1 + + + commons-io + commons-io + 2.5 + de.ruedigermoeller fst diff --git a/streaming/java/streaming-runtime/pom_template.xml b/streaming/java/streaming-runtime/pom_template.xml index c1f890c12..7b8da2a9e 100644 --- a/streaming/java/streaming-runtime/pom_template.xml +++ b/streaming/java/streaming-runtime/pom_template.xml @@ -1,8 +1,8 @@ - {auto_gen_header} + {auto_gen_header} + xmlns="http://maven.apache.org/POM/4.0.0" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> ray-streaming io.ray diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java index ac7ede0a9..3f1697149 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/StreamingGlobalConfig.java @@ -1,7 +1,9 @@ package io.ray.streaming.runtime.config; import 
com.google.common.base.Preconditions; +import io.ray.streaming.runtime.config.global.CheckpointConfig; import io.ray.streaming.runtime.config.global.CommonConfig; +import io.ray.streaming.runtime.config.global.ContextBackendConfig; import io.ray.streaming.runtime.config.global.TransferConfig; import java.io.Serializable; import java.lang.reflect.Method; @@ -19,17 +21,19 @@ import org.slf4j.LoggerFactory; public class StreamingGlobalConfig implements Serializable { private static final Logger LOG = LoggerFactory.getLogger(StreamingGlobalConfig.class); - public final CommonConfig commonConfig; public final TransferConfig transferConfig; - public final Map configMap; + public CheckpointConfig checkpointConfig; + public ContextBackendConfig contextBackendConfig; public StreamingGlobalConfig(final Map conf) { configMap = new HashMap<>(conf); commonConfig = ConfigFactory.create(CommonConfig.class, conf); transferConfig = ConfigFactory.create(TransferConfig.class, conf); + checkpointConfig = ConfigFactory.create(CheckpointConfig.class, conf); + contextBackendConfig = ConfigFactory.create(ContextBackendConfig.class, conf); globalConfig2Map(); } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java new file mode 100644 index 000000000..b31bc7d8c --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/CheckpointConfig.java @@ -0,0 +1,55 @@ +package io.ray.streaming.runtime.config.global; + +import io.ray.streaming.runtime.config.Config; +import org.aeonbits.owner.Mutable; + +/** + * Configurations for checkpointing. 
+ */ +public interface CheckpointConfig extends Config, Mutable { + + String CP_INTERVAL_SECS = "streaming.checkpoint.interval.secs"; + String CP_TIMEOUT_SECS = "streaming.checkpoint.timeout.secs"; + + String CP_PREFIX_KEY_MASTER = "streaming.checkpoint.prefix-key.job-master.context"; + String CP_PREFIX_KEY_WORKER = "streaming.checkpoint.prefix-key.job-worker.context"; + String CP_PREFIX_KEY_OPERATOR = "streaming.checkpoint.prefix-key.job-worker.operator"; + + /** + * Checkpoint time interval. JobMaster won't trigger 2 checkpoint in less than this time interval. + */ + @DefaultValue(value = "5") + @Key(value = CP_INTERVAL_SECS) + int cpIntervalSecs(); + + /** + * How long should JobMaster wait for checkpoint to finish. When this timeout is reached and + * JobMaster hasn't received all commits from workers, JobMaster will consider this checkpoint as + * failed and trigger another checkpoint. + */ + @DefaultValue(value = "120") + @Key(value = CP_TIMEOUT_SECS) + int cpTimeoutSecs(); + + /** + * This is used for saving JobMaster's context to storage, user usually don't need to change this. + */ + @DefaultValue(value = "job_master_runtime_context_") + @Key(value = CP_PREFIX_KEY_MASTER) + String jobMasterContextCpPrefixKey(); + + /** + * This is used for saving JobWorker's context to storage, user usually don't need to change this. + */ + @DefaultValue(value = "job_worker_context_") + @Key(value = CP_PREFIX_KEY_WORKER) + String jobWorkerContextCpPrefixKey(); + + /** + * This is used for saving user operator(in StreamTask)'s context to storage, user usually don't + * need to change this. 
+ */ + @DefaultValue(value = "job_worker_op_") + @Key(value = CP_PREFIX_KEY_OPERATOR) + String jobWorkerOpCpPrefixKey(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java new file mode 100644 index 000000000..11d1d3371 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/ContextBackendConfig.java @@ -0,0 +1,17 @@ +package io.ray.streaming.runtime.config.global; + +import org.aeonbits.owner.Config; + +public interface ContextBackendConfig extends Config { + + String STATE_BACKEND_TYPE = "streaming.context-backend.type"; + String FILE_STATE_ROOT_PATH = "streaming.context-backend.file-state.root"; + + @Config.DefaultValue(value = "memory") + @Key(value = STATE_BACKEND_TYPE) + String stateBackendType(); + + @Config.DefaultValue(value = "/tmp/ray_streaming_state") + @Key(value = FILE_STATE_ROOT_PATH) + String fileStateRootPath(); +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java index 7508dee2d..e6ea60d7a 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/global/TransferConfig.java @@ -22,13 +22,6 @@ public interface TransferConfig extends Config { @Key(value = io.ray.streaming.util.Config.CHANNEL_SIZE) long channelSize(); - /** - * DataRead read timeout. - */ - @DefaultValue(value = "false") - @Key(value = io.ray.streaming.util.Config.IS_RECREATE) - boolean readerIsRecreate(); - /** * Return from DataReader.getBundle if only empty message read in this interval. 
*/ diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java new file mode 100644 index 000000000..329e88c9a --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/config/types/ContextBackendType.java @@ -0,0 +1,22 @@ +package io.ray.streaming.runtime.config.types; + +public enum ContextBackendType { + + /** + * Memory type + */ + MEMORY("memory", 0), + + /** + * Local File + */ + LOCAL_FILE("local_file", 1); + + private String name; + private int index; + + ContextBackendType(String name, int index) { + this.name = name; + this.index = index; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java new file mode 100644 index 000000000..b14cdcbb9 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackend.java @@ -0,0 +1,42 @@ +package io.ray.streaming.runtime.context; + +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.worker.JobWorker; + +/** + * This interface is used for storing context of {@link JobWorker} and {@link JobMaster}. + * The checkpoint returned by user function is also saved using this interface. 
+ */ +public interface ContextBackend { + + /** + * check if key exists in state + * + * @return true if exists + */ + boolean exists(final String key) throws Exception; + + /** + * get content by key + * + * @param key key + * @return the content stored under the key + */ + byte[] get(final String key) throws Exception; + + /** + * put content by key + * + * @param key key + * @param value content + */ + void put(final String key, final byte[] value) throws Exception; + + /** + * remove content by key + * + * @param key key + */ + void remove(final String key) throws Exception; + +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java new file mode 100644 index 000000000..2ca96b5de --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/ContextBackendFactory.java @@ -0,0 +1,28 @@ +package io.ray.streaming.runtime.context; + +import io.ray.streaming.runtime.config.StreamingGlobalConfig; +import io.ray.streaming.runtime.config.types.ContextBackendType; +import io.ray.streaming.runtime.context.impl.AtomicFsBackend; +import io.ray.streaming.runtime.context.impl.MemoryContextBackend; + +public class ContextBackendFactory { + + public static ContextBackend getContextBackend(final StreamingGlobalConfig config) { + ContextBackend contextBackend; + // Locale.ROOT keeps the enum-name lookup independent of the JVM default locale. + ContextBackendType type = ContextBackendType.valueOf( + config.contextBackendConfig.stateBackendType().toUpperCase(java.util.Locale.ROOT)); + + switch (type) { + case MEMORY: + contextBackend = new MemoryContextBackend(config.contextBackendConfig); + break; + case LOCAL_FILE: + contextBackend = new AtomicFsBackend(config.contextBackendConfig); + break; + default: + throw new RuntimeException("Unsupported context backend type."); + } + return contextBackend; + } +} \ No newline at end of file diff --git 
a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java new file mode 100644 index 000000000..85bceb13c --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/OperatorCheckpointInfo.java @@ -0,0 +1,52 @@ +package io.ray.streaming.runtime.context; + +import com.google.common.base.MoreObjects; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +/** + * This data structure contains state information of a task. + */ +public class OperatorCheckpointInfo implements Serializable { + + /** + * key: channel ID, value: offset + */ + public Map inputPoints; + public Map outputPoints; + + /** + * a serializable checkpoint returned by processor + */ + public Serializable processorCheckpoint; + public long checkpointId; + + public OperatorCheckpointInfo() { + inputPoints = new HashMap<>(); + outputPoints = new HashMap<>(); + checkpointId = -1; + } + + public OperatorCheckpointInfo( + Map inputPoints, + Map outputPoints, + Serializable processorCheckpoint, + long checkpointId) { + this.inputPoints = inputPoints; + this.outputPoints = outputPoints; + this.checkpointId = checkpointId; + this.processorCheckpoint = processorCheckpoint; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("inputPoints", inputPoints) + .add("outputPoints", outputPoints) + .add("processorCheckpoint", processorCheckpoint) + .add("checkpointId", checkpointId) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java new file mode 100644 index 000000000..96288e281 --- /dev/null 
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/AtomicFsBackend.java @@ -0,0 +1,48 @@ +package io.ray.streaming.runtime.context.impl; + +import io.ray.streaming.runtime.config.global.ContextBackendConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Achieves an atomic `put` method. + * known issue: if you crashed while write a key at first time, this code will not work. + */ +public class AtomicFsBackend extends LocalFileContextBackend { + + private static final Logger LOG = LoggerFactory.getLogger(AtomicFsBackend.class); + private static final String TMP_FLAG = "_tmp"; + + public AtomicFsBackend(final ContextBackendConfig config) { + super(config); + } + + @Override + public byte[] get(String key) throws Exception { + String tmpKey = key + TMP_FLAG; + if (super.exists(tmpKey) && !super.exists(key)) { + return super.get(tmpKey); + } + return super.get(key); + } + + @Override + public void put(String key, byte[] value) throws Exception { + String tmpKey = key + TMP_FLAG; + if (super.exists(tmpKey) && !super.exists(key)) { + super.rename(tmpKey, key); + } + super.put(tmpKey, value); + super.remove(key); + super.rename(tmpKey, key); + } + + @Override + public void remove(String key) { + String tmpKey = key + TMP_FLAG; + if (super.exists(tmpKey)) { + super.remove(tmpKey); + } + super.remove(key); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java new file mode 100644 index 000000000..41e180462 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/LocalFileContextBackend.java @@ -0,0 +1,55 @@ +package io.ray.streaming.runtime.context.impl; + +import io.ray.streaming.runtime.config.global.ContextBackendConfig; +import 
io.ray.streaming.runtime.context.ContextBackend; +import java.io.File; +import org.apache.commons.io.FileUtils; + +/** + * This context backend uses local file system and doesn't supports failover in cluster. + * But it supports failover in single node. + * This is a pure file system backend which doesn't support atomic writing, please don't use this + * class, instead, use {@link AtomicFsBackend} which extends this class. + */ +public class LocalFileContextBackend implements ContextBackend { + + private final String rootPath; + + + public LocalFileContextBackend(ContextBackendConfig config) { + rootPath = config.fileStateRootPath(); + } + + @Override + public boolean exists(String key) { + File file = new File(rootPath, key); + return file.exists(); + } + + @Override + public byte[] get(String key) throws Exception { + File file = new File(rootPath, key); + if (file.exists()) { + return FileUtils.readFileToByteArray(file); + } + return null; + } + + @Override + public void put(String key, byte[] value) throws Exception { + File file = new File(rootPath, key); + FileUtils.writeByteArrayToFile(file, value); + } + + @Override + public void remove(String key) { + File file = new File(rootPath, key); + FileUtils.deleteQuietly(file); + } + + protected void rename(String fromKey, String toKey) throws Exception { + File srcFile = new File(rootPath, fromKey); + File dstFile = new File(rootPath, toKey); + FileUtils.moveFile(srcFile, dstFile); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java new file mode 100644 index 000000000..0a3723e05 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/context/impl/MemoryContextBackend.java @@ -0,0 +1,72 @@ +package io.ray.streaming.runtime.context.impl; + +import 
io.ray.streaming.runtime.config.global.ContextBackendConfig; +import io.ray.streaming.runtime.context.ContextBackend; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This context backend uses memory and doesn't supports failover. + * Data will be lost after worker died. + */ +public class MemoryContextBackend implements ContextBackend { + + private static final Logger LOG = LoggerFactory.getLogger(MemoryContextBackend.class); + + private final Map kvStore = new HashMap<>(); + + public MemoryContextBackend(ContextBackendConfig config) { + if (LOG.isInfoEnabled()) { + LOG.info("Start init memory state backend, config is {}.", config); + LOG.info("Finish init memory state backend."); + } + } + + @Override + public boolean exists(String key) { + return kvStore.containsKey(key); + } + + @Override + public byte[] get(final String key) { + if (LOG.isInfoEnabled()) { + LOG.info("Get value of key {} start.", key); + } + + byte[] readData = kvStore.get(key); + + if (LOG.isInfoEnabled()) { + LOG.info("Get value of key {} success.", key); + } + + return readData; + } + + @Override + public void put(final String key, final byte[] value) { + if (LOG.isInfoEnabled()) { + LOG.info("Put value of key {} start.", key); + } + + kvStore.put(key, value); + + if (LOG.isInfoEnabled()) { + LOG.info("Put value of key {} success.", key); + } + } + + @Override + public void remove(final String key) { + if (LOG.isInfoEnabled()) { + LOG.info("Remove value of key {} start.", key); + } + + kvStore.remove(key); + + if (LOG.isInfoEnabled()) { + LOG.info("Remove value of key {} success.", key); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java index 1006165f3..90a8e2b18 100644 --- 
a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/collector/OutputCollector.java @@ -9,8 +9,8 @@ import io.ray.streaming.message.Record; import io.ray.streaming.runtime.serialization.CrossLangSerializer; import io.ray.streaming.runtime.serialization.JavaSerializer; import io.ray.streaming.runtime.serialization.Serializer; -import io.ray.streaming.runtime.transfer.ChannelId; import io.ray.streaming.runtime.transfer.DataWriter; +import io.ray.streaming.runtime.transfer.channel.ChannelId; import java.nio.ByteBuffer; import java.util.Collection; import org.slf4j.Logger; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java index b9d19bf2d..dcbf6b1ff 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionGraph.java @@ -1,19 +1,29 @@ package io.ray.streaming.runtime.core.graph.executiongraph; +import com.google.common.collect.Sets; import io.ray.api.BaseActorHandle; +import io.ray.api.id.ActorId; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Physical plan. */ public class ExecutionGraph implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(ExecutionGraph.class); + /** * Name of the job. 
*/ @@ -29,6 +39,27 @@ public class ExecutionGraph implements Serializable { */ private Map executionJobVertexMap; + /** + * Data map for execution vertex. + * key: execution vertex id. + * value: execution vertex. + */ + private Map executionVertexMap; + + /** + * Data map for execution vertex. + * key: actor id. + * value: execution vertex. + */ + private Map actorIdExecutionVertexMap; + + + /** + * key: channel ID + * value: actors in both sides of this channel + */ + private Map> channelGroupedActors; + /** * The max parallelism of the whole graph. */ @@ -54,7 +85,7 @@ public class ExecutionGraph implements Serializable { } public List getExecutionJobVertexList() { - return new ArrayList(executionJobVertexMap.values()); + return new ArrayList<>(executionJobVertexMap.values()); } public Map getExecutionJobVertexMap() { @@ -65,6 +96,58 @@ public class ExecutionGraph implements Serializable { this.executionJobVertexMap = executionJobVertexMap; } + + /** + * generate relation mappings between actors, execution vertices and channels + * this method must be called after worker actor is set. 
+ */ + public void generateActorMappings() { + LOG.info("Setup queue actors relation."); + + channelGroupedActors = new HashMap<>(); + actorIdExecutionVertexMap = new HashMap<>(); + + getAllExecutionVertices().forEach(curVertex -> { + + // current + actorIdExecutionVertexMap.put(curVertex.getActorId(), curVertex); + + // input + List inputEdges = curVertex.getInputEdges(); + inputEdges.forEach(inputEdge -> { + ExecutionVertex inputVertex = inputEdge.getSourceExecutionVertex(); + String channelId = curVertex.getChannelIdByPeerVertex(inputVertex); + addActorToChannelGroupedActors(channelGroupedActors, channelId, + inputVertex.getWorkerActor()); + }); + + // output + List outputEdges = curVertex.getOutputEdges(); + outputEdges.forEach(outputEdge -> { + ExecutionVertex outputVertex = outputEdge.getTargetExecutionVertex(); + String channelId = curVertex.getChannelIdByPeerVertex(outputVertex); + addActorToChannelGroupedActors(channelGroupedActors, channelId, + outputVertex.getWorkerActor()); + }); + }); + + LOG.debug("Channel grouped actors is: {}.", channelGroupedActors); + } + + private void addActorToChannelGroupedActors( + Map> channelGroupedActors, + String queueName, + BaseActorHandle actor) { + + Set actorSet = + channelGroupedActors.computeIfAbsent(queueName, k -> new HashSet<>()); + actorSet.add(actor); + } + + public void setExecutionVertexMap(Map executionVertexMap) { + this.executionVertexMap = executionVertexMap; + } + public Map getJobConfig() { return jobConfig; } @@ -114,25 +197,73 @@ public class ExecutionGraph implements Serializable { return executionJobVertexMap.values().stream() .map(ExecutionJobVertex::getExecutionVertices) .flatMap(Collection::stream) - .filter(vertex -> vertex.is2Add()) + .filter(ExecutionVertex::is2Add) .collect(Collectors.toList()); } /** * Get specified execution vertex from current execution graph by execution vertex id. * - * @param vertexId execution vertex id. + * @param executionVertexId execution vertex id. 
* @return the specified execution vertex. */ - public ExecutionVertex getExecutionJobVertexByJobVertexId(int vertexId) { - for (ExecutionJobVertex executionJobVertex : executionJobVertexMap.values()) { - for (ExecutionVertex executionVertex : executionJobVertex.getExecutionVertices()) { - if (executionVertex.getExecutionVertexId() == vertexId) { - return executionVertex; - } - } + public ExecutionVertex getExecutionVertexByExecutionVertexId(int executionVertexId) { + if (executionVertexMap.containsKey(executionVertexId)) { + return executionVertexMap.get(executionVertexId); } - throw new RuntimeException("Vertex " + vertexId + " does not exist!"); + throw new RuntimeException("Vertex " + executionVertexId + " does not exist!"); + } + + + /** + * Get specified execution vertex from current execution graph by actor id. + * + * @param actorId the actor id of execution vertex. + * @return the specified execution vertex. + */ + public ExecutionVertex getExecutionVertexByActorId(ActorId actorId) { + return actorIdExecutionVertexMap.get(actorId); + } + + + /** + * Get specified actor by actor id. + * + * @param actorId the actor id of execution vertex. + * @return the specified actor handle. 
+ */ + public Optional getActorById(ActorId actorId) { + return getAllActors().stream() + .filter(actor -> actor.getId().equals(actorId)) + .findFirst(); + } + + /** + * Get the peer actor in the other side of channelName of a given actor + * + * @param actor actor in this side + * @param channelName the channel name + * @return the peer actor in the other side + */ + public BaseActorHandle getPeerActor(BaseActorHandle actor, String channelName) { + Set set = getActorsByChannelId(channelName); + final BaseActorHandle[] res = new BaseActorHandle[1]; + set.forEach(anActor -> { + if (!anActor.equals(actor)) { + res[0] = anActor; + } + }); + return res[0]; + } + + /** + * Get actors in both sides of a channelId + * + * @param channelId the channelId + * @return actors in both sides + */ + public Set getActorsByChannelId(String channelId) { + return channelGroupedActors.getOrDefault(channelId, Sets.newHashSet()); } /** @@ -202,4 +333,27 @@ public class ExecutionGraph implements Serializable { .collect(Collectors.toList()); } + public Set getActorName(Set actorIds) { + return getAllExecutionVertices().stream() + .filter(executionVertex -> actorIds.contains(executionVertex.getActorId())) + .map(ExecutionVertex::getActorName) + .collect(Collectors.toSet()); + } + + public String getActorName(ActorId actorId) { + Set set = Sets.newHashSet(); + set.add(actorId); + Set result = getActorName(set); + if (result.isEmpty()) { + return null; + } + return result.iterator().next(); + } + + public List getAllActorsId() { + return getAllActors().stream() + .map(BaseActorHandle::getId) + .collect(Collectors.toList()); + } + } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java index 6ab2fd911..6aa7936b2 100644 --- 
a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobEdge.java @@ -3,11 +3,12 @@ package io.ray.streaming.runtime.core.graph.executiongraph; import com.google.common.base.MoreObjects; import io.ray.streaming.api.partition.Partition; import io.ray.streaming.jobgraph.JobEdge; +import java.io.Serializable; /** * An edge that connects two execution job vertices. */ -public class ExecutionJobEdge { +public class ExecutionJobEdge implements Serializable { /** * The source(upstream) execution job vertex. diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java index f0c87bd0f..b617cc053 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionJobVertex.java @@ -8,6 +8,7 @@ import io.ray.streaming.jobgraph.JobVertex; import io.ray.streaming.jobgraph.VertexType; import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.runtime.config.master.ResourceConfig; +import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -20,7 +21,7 @@ import org.aeonbits.owner.ConfigFactory; *

Execution job vertex is the physical form of {@link JobVertex} and * every execution job vertex is corresponding to a group of {@link ExecutionVertex}. */ -public class ExecutionJobVertex { +public class ExecutionJobVertex implements Serializable { /** * Unique id. Use {@link JobVertex}'s id directly. diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java index 0135b35ed..5d6a2556c 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/graph/executiongraph/ExecutionVertex.java @@ -9,6 +9,7 @@ import io.ray.streaming.operator.StreamOperator; import io.ray.streaming.runtime.config.master.ResourceConfig; import io.ray.streaming.runtime.core.resource.ContainerId; import io.ray.streaming.runtime.core.resource.ResourceType; +import io.ray.streaming.runtime.transfer.channel.ChannelId; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; @@ -60,6 +61,8 @@ public class ExecutionVertex implements Serializable { */ private ContainerId containerId; + private String pid; + /** * Worker actor handle. 
*/ @@ -73,6 +76,14 @@ public class ExecutionVertex implements Serializable { private List inputEdges = new ArrayList<>(); private List outputEdges = new ArrayList<>(); + private transient List outputChannelIdList; + private transient List inputChannelIdList; + + private transient List outputActorList; + private transient List inputActorList; + private Map exeVertexChannelMap; + + public ExecutionVertex( int globalIndex, int index, @@ -92,9 +103,7 @@ public class ExecutionVertex implements Serializable { } private Map genWorkerConfig(Map jobConfig) { - Map workerConfig = new HashMap<>(); - workerConfig.putAll(jobConfig); - return workerConfig; + return new HashMap<>(jobConfig); } public int getExecutionVertexId() { @@ -161,14 +170,14 @@ public class ExecutionVertex implements Serializable { return workerActor; } - public ActorId getWorkerActorId() { - return workerActor.getId(); - } - public void setWorkerActor(BaseActorHandle workerActor) { this.workerActor = workerActor; } + public ActorId getWorkerActorId() { + return workerActor.getId(); + } + public List getInputEdges() { return inputEdges; } @@ -199,6 +208,14 @@ public class ExecutionVertex implements Serializable { .collect(Collectors.toList()); } + public ActorId getActorId() { + return null == workerActor ? 
null : workerActor.getId(); + } + + public String getActorName() { + return String.valueOf(executionVertexId); + } + public Map getResource() { return resource; } @@ -219,12 +236,89 @@ public class ExecutionVertex implements Serializable { this.containerId = containerId; } + public String getPid() { + return pid; + } + + public void setPid(String pid) { + this.pid = pid; + } + public void setContainerIfNotExist(ContainerId containerId) { if (null == this.containerId) { this.containerId = containerId; } } + /*---------channel-actor relations---------*/ + public List getOutputChannelIdList() { + if (outputChannelIdList == null) { + generateActorChannelInfo(); + } + return outputChannelIdList; + } + + public List getOutputActorList() { + if (outputActorList == null) { + generateActorChannelInfo(); + } + return outputActorList; + } + + public List getInputChannelIdList() { + if (inputChannelIdList == null) { + generateActorChannelInfo(); + } + return inputChannelIdList; + } + + public List getInputActorList() { + if (inputActorList == null) { + generateActorChannelInfo(); + } + return inputActorList; + } + + + public String getChannelIdByPeerVertex(ExecutionVertex peerVertex) { + if (exeVertexChannelMap == null) { + generateActorChannelInfo(); + } + return exeVertexChannelMap.get(peerVertex.getExecutionVertexId()); + } + + + private void generateActorChannelInfo() { + inputChannelIdList = new ArrayList<>(); + inputActorList = new ArrayList<>(); + outputChannelIdList = new ArrayList<>(); + outputActorList = new ArrayList<>(); + exeVertexChannelMap = new HashMap<>(); + + List inputEdges = getInputEdges(); + for (ExecutionEdge edge : inputEdges) { + String channelId = ChannelId.genIdStr( + edge.getSourceExecutionVertex().getExecutionVertexId(), + getExecutionVertexId(), + getBuildTime()); + inputChannelIdList.add(channelId); + inputActorList.add(edge.getSourceExecutionVertex().getWorkerActor()); + 
exeVertexChannelMap.put(edge.getSourceExecutionVertex().getExecutionVertexId(), channelId); + } + + List outputEdges = getOutputEdges(); + for (ExecutionEdge edge : outputEdges) { + String channelId = ChannelId.genIdStr( + getExecutionVertexId(), + edge.getTargetExecutionVertex().getExecutionVertexId(), + getBuildTime()); + outputChannelIdList.add(channelId); + outputActorList.add(edge.getTargetExecutionVertex().getWorkerActor()); + exeVertexChannelMap.put(edge.getTargetExecutionVertex().getExecutionVertexId(), channelId); + } + } + + private Map generateResources(ResourceConfig resourceConfig) { Map resourceMap = new HashMap<>(); if (resourceConfig.isTaskCpuResourceLimit()) { diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java index 500721d3d..d189c42c1 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/ProcessBuilder.java @@ -15,7 +15,7 @@ public class ProcessBuilder { public static StreamProcessor buildProcessor(StreamOperator streamOperator) { OperatorType type = streamOperator.getOpType(); LOGGER.info("Building StreamProcessor, operator type = {}, operator = {}.", type, - streamOperator.getClass().getSimpleName().toString()); + streamOperator.getClass().getSimpleName()); switch (type) { case SOURCE: return new SourceProcessor<>((SourceOperator) streamOperator); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java index 3b128376c..54fe76cd8 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/Processor.java @@ -2,6 +2,7 @@ package io.ray.streaming.runtime.core.processor; import io.ray.streaming.api.collector.Collector; import io.ray.streaming.api.context.RuntimeContext; +import io.ray.streaming.api.function.Function; import java.io.Serializable; import java.util.List; @@ -11,5 +12,15 @@ public interface Processor extends Serializable { void process(T t); + /** + * See {@link Function#saveCheckpoint()}. + */ + Serializable saveCheckpoint(); + + /** + * See {@link Function#loadCheckpoint(Serializable)}. + */ + void loadCheckpoint(Serializable checkpointObject); + void close(); } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java index 020f39d16..1cc721a2a 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/core/processor/SourceProcessor.java @@ -19,8 +19,8 @@ public class SourceProcessor extends StreamProcessor implements Processo LOGGER.info("opened {}", this); } + @Override + public Serializable saveCheckpoint() { + return operator.saveCheckpoint(); + } + + @Override + public void loadCheckpoint(Serializable checkpointObject) { + operator.loadCheckpoint(checkpointObject); + } + @Override public String toString() { return this.getClass().getSimpleName(); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java index 60d3a0843..6115e4d50 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/JobMaster.java @@ -1,18 +1,36 @@ package io.ray.streaming.runtime.master; import com.google.common.base.Preconditions; +import com.google.protobuf.InvalidProtocolBufferException; import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; +import io.ray.api.Ray; +import io.ray.api.id.ActorId; import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.runtime.config.StreamingConfig; import io.ray.streaming.runtime.config.StreamingMasterConfig; +import io.ray.streaming.runtime.context.ContextBackend; +import io.ray.streaming.runtime.context.ContextBackendFactory; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import io.ray.streaming.runtime.master.coordinator.CheckpointCoordinator; +import io.ray.streaming.runtime.master.coordinator.FailoverCoordinator; +import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; +import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; import io.ray.streaming.runtime.master.graphmanager.GraphManager; import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; import io.ray.streaming.runtime.master.resourcemanager.ResourceManager; import io.ray.streaming.runtime.master.resourcemanager.ResourceManagerImpl; import io.ray.streaming.runtime.master.scheduler.JobSchedulerImpl; +import io.ray.streaming.runtime.util.CheckpointStateUtil; +import io.ray.streaming.runtime.util.ResourceUtil; +import io.ray.streaming.runtime.util.Serializer; import io.ray.streaming.runtime.worker.JobWorker; import java.util.Map; +import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,33 +42,68 @@ 
public class JobMaster { private static final Logger LOG = LoggerFactory.getLogger(JobMaster.class); - private JobRuntimeContext runtimeContext; + private JobMasterRuntimeContext runtimeContext; private ResourceManager resourceManager; private JobSchedulerImpl scheduler; private GraphManager graphManager; private StreamingMasterConfig conf; + private ContextBackend contextBackend; + private ActorHandle jobMasterActor; + // coordinators + private CheckpointCoordinator checkpointCoordinator; + private FailoverCoordinator failoverCoordinator; + public JobMaster(Map confMap) { LOG.info("Creating job master with conf: {}.", confMap); StreamingConfig streamingConfig = new StreamingConfig(confMap); this.conf = streamingConfig.masterConfig; + this.contextBackend = ContextBackendFactory.getContextBackend(this.conf); // init runtime context - runtimeContext = new JobRuntimeContext(streamingConfig); + runtimeContext = new JobMasterRuntimeContext(streamingConfig); + + // load checkpoint if is recover + if (Ray.getRuntimeContext().wasCurrentActorRestarted()) { + loadMasterCheckpoint(); + } LOG.info("Finished creating job master."); } + public static String getJobMasterRuntimeContextKey(StreamingMasterConfig conf) { + return conf.checkpointConfig.jobMasterContextCpPrefixKey() + conf.commonConfig.jobName(); + } + + private void loadMasterCheckpoint() { + LOG.info("Start to load JobMaster's checkpoint."); + // recover runtime context + byte[] bytes = + CheckpointStateUtil.get(contextBackend, getJobMasterRuntimeContextKey(getConf())); + if (bytes == null) { + LOG.warn("JobMaster got empty checkpoint from state backend. Skip loading checkpoint."); + // cp 0 was automatically saved when job started, see StreamTask. 
+ runtimeContext.checkpointIds.add(0L); + return; + } + + this.runtimeContext = Serializer.decode(bytes); + + // FO case, triggered by ray, we need to register context when loading checkpoint + LOG.info("JobMaster recover runtime context[{}] from state backend.", runtimeContext); + init(true); + } + /** * Init JobMaster. To initiate or recover other components(like metrics and extra coordinators). * * @return init result */ - public Boolean init() { - LOG.info("Initializing job master."); + public Boolean init(boolean isRecover) { + LOG.info("Initializing job master, isRecover={}.", isRecover); if (this.runtimeContext.getExecutionGraph() == null) { LOG.error("Init job master failed. Job graphs is null."); @@ -60,6 +113,14 @@ public class JobMaster { ExecutionGraph executionGraph = graphManager.getExecutionGraph(); Preconditions.checkArgument(executionGraph != null, "no execution graph"); + // init coordinators + checkpointCoordinator = new CheckpointCoordinator(this); + checkpointCoordinator.start(); + failoverCoordinator = new FailoverCoordinator(this, isRecover); + failoverCoordinator.start(); + + saveContext(); + LOG.info("Finished initializing job master."); return true; } @@ -101,11 +162,86 @@ public class JobMaster { return true; } + public synchronized void saveContext() { + if (runtimeContext != null && getConf() != null) { + LOG.debug("Save JobMaster context."); + + byte[] contextBytes = Serializer.encode(runtimeContext); + CheckpointStateUtil + .put(contextBackend, getJobMasterRuntimeContextKey(getConf()), contextBytes); + } + } + + public byte[] reportJobWorkerCommit(byte[] reportBytes) { + Boolean ret = false; + RemoteCall.BaseWorkerCmd reportPb; + try { + reportPb = RemoteCall.BaseWorkerCmd.parseFrom(reportBytes); + ActorId actorId = ActorId.fromBytes(reportPb.getActorId().toByteArray()); + long remoteCallCost = System.currentTimeMillis() - reportPb.getTimestamp(); + LOG.info("Vertex {}, request job worker commit cost {}ms, actorId={}.", + 
getExecutionVertex(actorId), remoteCallCost, actorId); + RemoteCall.WorkerCommitReport commit = + reportPb.getDetail().unpack(RemoteCall.WorkerCommitReport.class); + WorkerCommitReport report = new WorkerCommitReport(actorId, commit.getCommitCheckpointId()); + ret = checkpointCoordinator.reportJobWorkerCommit(report); + } catch (InvalidProtocolBufferException e) { + LOG.error("Parse job worker commit has exception.", e); + } + return RemoteCall.BoolResult.newBuilder().setBoolRes(ret).build().toByteArray(); + } + + public byte[] requestJobWorkerRollback(byte[] requestBytes) { + Boolean ret = false; + RemoteCall.BaseWorkerCmd requestPb; + try { + requestPb = RemoteCall.BaseWorkerCmd.parseFrom(requestBytes); + ActorId actorId = ActorId.fromBytes(requestPb.getActorId().toByteArray()); + long remoteCallCost = System.currentTimeMillis() - requestPb.getTimestamp(); + ExecutionGraph executionGraph = graphManager.getExecutionGraph(); + Optional rayActor = executionGraph.getActorById(actorId); + if (!rayActor.isPresent()) { + LOG.warn("Skip this invalid rollback, actor id {} is not found.", actorId); + return RemoteCall.BoolResult.newBuilder().setBoolRes(false).build().toByteArray(); + } + ExecutionVertex exeVertex = getExecutionVertex(actorId); + LOG.info("Vertex {}, request job worker rollback cost {}ms, actorId={}.", + exeVertex, remoteCallCost, actorId); + RemoteCall.WorkerRollbackRequest rollbackPb + = RemoteCall.WorkerRollbackRequest.parseFrom(requestPb.getDetail().getValue()); + exeVertex.setPid(rollbackPb.getWorkerPid()); + // To find old container where slot is located in. 
+ String hostname = ""; + Optional container = ResourceUtil.getContainerById( + resourceManager.getRegisteredContainers(), + exeVertex.getContainerId() + ); + if (container.isPresent()) { + hostname = container.get().getHostname(); + } + WorkerRollbackRequest request = new WorkerRollbackRequest( + actorId, rollbackPb.getExceptionMsg(), hostname, exeVertex.getPid() + ); + + ret = failoverCoordinator.requestJobWorkerRollback(request); + LOG.info("Vertex {} request rollback, exception msg : {}.", + exeVertex, rollbackPb.getExceptionMsg()); + + } catch (Throwable e) { + LOG.error("Parse job worker rollback has exception.", e); + } + return RemoteCall.BoolResult.newBuilder().setBoolRes(ret).build().toByteArray(); + } + + private ExecutionVertex getExecutionVertex(ActorId id) { + return graphManager.getExecutionGraph().getExecutionVertexByActorId(id); + } + public ActorHandle getJobMasterActor() { return jobMasterActor; } - public JobRuntimeContext getRuntimeContext() { + public JobMasterRuntimeContext getRuntimeContext() { return runtimeContext; } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java new file mode 100644 index 000000000..c9e6e8f57 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/context/JobMasterRuntimeContext.java @@ -0,0 +1,81 @@ +package io.ray.streaming.runtime.master.context; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.Sets; +import io.ray.streaming.jobgraph.JobGraph; +import io.ray.streaming.runtime.config.StreamingConfig; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; +import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; 
+import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +/** + * Runtime context for job master, which will be stored in backend when saving checkpoint. + * + *

Including: graph, resource, checkpoint info, etc. + */ +public class JobMasterRuntimeContext implements Serializable { + + /*--------------Checkpoint----------------*/ + public volatile List checkpointIds = new ArrayList<>(); + public volatile long lastCheckpointId = 0; + public volatile long lastCpTimestamp = 0; + public volatile BlockingQueue cpCmds = new LinkedBlockingQueue<>(); + /*--------------Failover----------------*/ + public volatile BlockingQueue foCmds = new ArrayBlockingQueue<>(8192); + public volatile Set unfinishedFoCmds = Sets.newConcurrentHashSet(); + private StreamingConfig conf; + private JobGraph jobGraph; + private volatile ExecutionGraph executionGraph; + + public JobMasterRuntimeContext(StreamingConfig conf) { + this.conf = conf; + } + + public String getJobName() { + return conf.masterConfig.commonConfig.jobName(); + } + + public StreamingConfig getConf() { + return conf; + } + + public JobGraph getJobGraph() { + return jobGraph; + } + + public void setJobGraph(JobGraph jobGraph) { + this.jobGraph = jobGraph; + } + + public ExecutionGraph getExecutionGraph() { + return executionGraph; + } + + public void setExecutionGraph(ExecutionGraph executionGraph) { + this.executionGraph = executionGraph; + } + + public Long getLastValidCheckpointId() { + if (checkpointIds.isEmpty()) { + // 0L is an invalid checkpoint id, worker will pass it + return 0L; + } + return checkpointIds.get(checkpointIds.size() - 1); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("jobGraph", jobGraph) + .add("executionGraph", executionGraph) + .add("conf", conf.getMap()) + .toString(); + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java new file mode 100644 index 000000000..ece4de4b7 --- /dev/null +++
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/BaseCoordinator.java
@@ -0,0 +1,58 @@
+package io.ray.streaming.runtime.master.coordinator;
+
+import io.ray.api.Ray;
+import io.ray.streaming.runtime.master.JobMaster;
+import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
+import io.ray.streaming.runtime.master.graphmanager.GraphManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Base class of the job master coordinators. A coordinator is a {@link Runnable} whose run
+ * loop is executed on a dedicated thread created by {@link #start()} and is asked to exit by
+ * {@link #stop()}.
+ */
+public abstract class BaseCoordinator implements Runnable {
+
+  private static final Logger LOG = LoggerFactory.getLogger(BaseCoordinator.class);
+
+  protected final JobMaster jobMaster;
+
+  protected final JobMasterRuntimeContext runtimeContext;
+  protected final GraphManager graphManager;
+  /** Set by {@link #stop()}; subclasses' run loops poll it to know when to exit. */
+  protected volatile boolean closed;
+  private Thread thread;
+
+  public BaseCoordinator(JobMaster jobMaster) {
+    this.jobMaster = jobMaster;
+    this.runtimeContext = jobMaster.getRuntimeContext();
+    this.graphManager = jobMaster.getGraphManager();
+  }
+
+  /**
+   * Starts the coordinator loop on a new thread named after the concrete class.
+   */
+  public void start() {
+    thread = new Thread(Ray.wrapRunnable(this),
+        this.getClass().getName() + "-" + System.currentTimeMillis());
+    thread.start();
+  }
+
+  /**
+   * Marks the coordinator as closed and waits up to 30 seconds for its thread to exit.
+   */
+  public void stop() {
+    closed = true;
+
+    try {
+      if (thread != null) {
+        thread.join(30000);
+      }
+    } catch (InterruptedException e) {
+      LOG.error("Coordinator thread exit has exception.", e);
+      // Restore the interrupt status so that callers can still observe the interruption.
+      Thread.currentThread().interrupt();
+    }
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java
new file mode 100644
index 000000000..862528776
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/CheckpointCoordinator.java
@@ -0,0 +1,246 @@
+package io.ray.streaming.runtime.master.coordinator;
+
+import com.google.common.base.Preconditions;
+import io.ray.api.BaseActorHandle;
+import io.ray.api.ObjectRef;
+import io.ray.api.id.ActorId;
+import io.ray.runtime.exception.RayException;
+import io.ray.streaming.runtime.master.JobMaster;
+import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd;
+import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport;
+import io.ray.streaming.runtime.rpc.RemoteCallWorker;
+import io.ray.streaming.runtime.worker.JobWorker;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * CheckpointCoordinator is the controller of checkpoint, responsible for triggering checkpoint,
+ * collecting {@link JobWorker}'s reports and calling {@link JobWorker} to clear expired
+ * checkpoints when new checkpoint finished.
+ */
+public class CheckpointCoordinator extends BaseCoordinator {
+
+  private static final Logger LOG = LoggerFactory.getLogger(CheckpointCoordinator.class);
+  /** Actors whose commit report for the checkpoint in progress has not been processed yet. */
+  private final Set pendingCheckpointActors = new HashSet<>();
+  /** Checkpoint ids already interrupted; used to skip duplicated interruptions. */
+  private final Set interruptedCheckpointSet = new HashSet<>();
+  private final int cpIntervalSecs;
+  private final int cpTimeoutSecs;
+
+  public CheckpointCoordinator(JobMaster jobMaster) {
+    super(jobMaster);
+
+    // get checkpoint interval from conf
+    this.cpIntervalSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpIntervalSecs();
+    this.cpTimeoutSecs = runtimeContext.getConf().masterConfig.checkpointConfig.cpTimeoutSecs();
+
+    // Trigger next checkpoint in interval by reset last checkpoint timestamp.
+    runtimeContext.lastCpTimestamp = System.currentTimeMillis();
+  }
+
+  /**
+   * Coordinator loop: waits up to one second for a command, processes commit reports or
+   * interrupt requests, interrupts a checkpoint that timed out and triggers the next one when due.
+   */
+  @Override
+  public void run() {
+    while (!closed) {
+      try {
+        final BaseWorkerCmd command = runtimeContext.cpCmds.poll(1, TimeUnit.SECONDS);
+        if (command != null) {
+          if (command instanceof WorkerCommitReport) {
+            processCommitReport((WorkerCommitReport) command);
+          } else {
+            interruptCheckpoint();
+          }
+        }
+
+        if (!pendingCheckpointActors.isEmpty()) {
+          // if wait commit report timeout, this cp fail, and restart next cp
+          if (timeoutOnWaitCheckpoint()) {
+            LOG.warn("Waiting for checkpoint {} timeout, pending cp actors is {}.",
+                runtimeContext.lastCheckpointId,
+                graphManager.getExecutionGraph().getActorName(pendingCheckpointActors));
+
+            interruptCheckpoint();
+          }
+        } else {
+          maybeTriggerCheckpoint();
+        }
+      } catch (Throwable e) {
+        LOG.error("Checkpoint coordinator occur err.", e);
+        try {
+          interruptCheckpoint();
+        } catch (Throwable interruptE) {
+          LOG.error("Ignore interrupt checkpoint exception in catch block.", interruptE);
+        }
+      }
+    }
+    LOG.warn("Checkpoint coordinator thread exit.");
+  }
+
+  /**
+   * Enqueues a worker commit report to be handled by the coordinator thread.
+   *
+   * @param report commit report sent by a job worker
+   * @return true if the report was enqueued, false if the command queue rejected it
+   */
+  public Boolean reportJobWorkerCommit(WorkerCommitReport report) {
+    LOG.info("Report job worker commit {}.", report);
+
+    Boolean ret = runtimeContext.cpCmds.offer(report);
+    if (!ret) {
+      LOG.warn("Report job worker commit failed, because command queue is full.");
+    }
+    return ret;
+  }
+
+  /**
+   * Removes the reporting actor from the pending set; once no actor is pending, records the
+   * checkpoint id, clears expired checkpoints and saves the master context.
+   */
+  private void processCommitReport(WorkerCommitReport commitReport) {
+    LOG.info("Start process commit report {}, from actor name={}.", commitReport,
+        graphManager.getExecutionGraph().getActorName(commitReport.fromActorId));
+
+    try {
+      Preconditions.checkArgument(
+          commitReport.commitCheckpointId == runtimeContext.lastCheckpointId,
+          "expect checkpointId %s, but got %s",
+          runtimeContext.lastCheckpointId, commitReport);
+
+      if (!pendingCheckpointActors.contains(commitReport.fromActorId)) {
+        LOG.warn("Invalid commit report, skipped.");
+        return;
+      }
+
+      pendingCheckpointActors.remove(commitReport.fromActorId);
+      LOG.info("Pending actors after this commit: {}.",
+          graphManager.getExecutionGraph().getActorName(pendingCheckpointActors));
+
+      // checkpoint finish
+      if (pendingCheckpointActors.isEmpty()) {
+        // actor finish
+        runtimeContext.checkpointIds.add(runtimeContext.lastCheckpointId);
+
+        if (clearExpiredCpStateAndQueueMsg()) {
+          // save master context
+          jobMaster.saveContext();
+
+          LOG.info("Finish checkpoint: {}.", runtimeContext.lastCheckpointId);
+        } else {
+          LOG.warn("Fail to do checkpoint: {}.", runtimeContext.lastCheckpointId);
+        }
+      }
+
+      LOG.info("Process commit report {} success.", commitReport);
+    } catch (Throwable e) {
+      LOG.warn("Process commit report has exception.", e);
+    }
+  }
+
+  /**
+   * Starts checkpoint {@code lastCheckpointId + 1}: marks every actor as pending, then asks the
+   * source actors to trigger it and rethrows any {@link RayException} they return.
+   */
+  private void triggerCheckpoint() {
+    interruptedCheckpointSet.clear();
+    if (LOG.isInfoEnabled()) {
+      LOG.info("Start trigger checkpoint {}.", runtimeContext.lastCheckpointId + 1);
+    }
+
+    List allIds = graphManager.getExecutionGraph().getAllActorsId();
+    // do the checkpoint
+    pendingCheckpointActors.addAll(allIds);
+
+    // inc last checkpoint id
+    ++runtimeContext.lastCheckpointId;
+
+    final List sourcesRet = new ArrayList<>();
+
+    graphManager.getExecutionGraph().getSourceActors().forEach(actor -> {
+      sourcesRet.add(RemoteCallWorker.triggerCheckpoint(
+          actor, runtimeContext.lastCheckpointId));
+    });
+
+    for (ObjectRef rayObject : sourcesRet) {
+      if (rayObject.get() instanceof RayException) {
+        LOG.warn("Trigger checkpoint has exception.", (RayException) rayObject.get());
+        throw (RayException) rayObject.get();
+      }
+    }
+    runtimeContext.lastCpTimestamp = System.currentTimeMillis();
+    LOG.info("Trigger checkpoint success.");
+  }
+
+  /**
+   * Abandons the current checkpoint (at most once per checkpoint id), notifies all actors if it
+   * is newer than the last valid checkpoint, and triggers the next checkpoint if it is due.
+   */
+  private void interruptCheckpoint() {
+    // notify checkpoint timeout is time-consuming while many workers crash or
+    // container failover.
+    if (interruptedCheckpointSet.contains(runtimeContext.lastCheckpointId)) {
+      LOG.warn("Skip interrupt duplicated checkpoint id : {}.", runtimeContext.lastCheckpointId);
+      return;
+    }
+    interruptedCheckpointSet.add(runtimeContext.lastCheckpointId);
+    LOG.warn("Interrupt checkpoint, checkpoint id : {}.", runtimeContext.lastCheckpointId);
+
+    List allActor = graphManager.getExecutionGraph().getAllActors();
+    if (runtimeContext.lastCheckpointId > runtimeContext.getLastValidCheckpointId()) {
+      RemoteCallWorker
+          .notifyCheckpointTimeoutParallel(allActor, runtimeContext.lastCheckpointId);
+    }
+
+    if (!pendingCheckpointActors.isEmpty()) {
+      pendingCheckpointActors.clear();
+    }
+    maybeTriggerCheckpoint();
+  }
+
+  private void maybeTriggerCheckpoint() {
+    if (readyToTrigger()) {
+      triggerCheckpoint();
+    }
+  }
+
+  /**
+   * Asks all actors to clear expired checkpoint data after a checkpoint finished.
+   *
+   * @return always true
+   */
+  private boolean clearExpiredCpStateAndQueueMsg() {
+    // queue msg must clear when first checkpoint finish
+    List allActor = graphManager.getExecutionGraph().getAllActors();
+    if (1 == runtimeContext.checkpointIds.size()) {
+      Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0);
+      RemoteCallWorker.clearExpiredCheckpointParallel(allActor, 0L, msgExpiredCheckpointId);
+    }
+
+    if (runtimeContext.checkpointIds.size() > 1) {
+      Long stateExpiredCpId = runtimeContext.checkpointIds.remove(0);
+      Long msgExpiredCheckpointId = runtimeContext.checkpointIds.get(0);
+      RemoteCallWorker
+          .clearExpiredCheckpointParallel(allActor, stateExpiredCpId, msgExpiredCheckpointId);
+    }
+    return true;
+  }
+
+  /** True when the checkpoint interval has elapsed (computed in long to avoid int overflow). */
+  private boolean readyToTrigger() {
+    return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp) >=
+        cpIntervalSecs * 1000L;
+  }
+
+  /** True when the checkpoint timeout has elapsed (computed in long to avoid int overflow). */
+  private boolean timeoutOnWaitCheckpoint() {
+    return (System.currentTimeMillis() - runtimeContext.lastCpTimestamp) >= cpTimeoutSecs * 1000L;
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java
new file mode 100644
index 000000000..c58c84d6a
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/FailoverCoordinator.java
@@ -0,0 +1,322 @@
+package io.ray.streaming.runtime.master.coordinator;
+
+import io.ray.api.BaseActorHandle;
+import io.ray.api.id.ActorId;
+import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
+import io.ray.streaming.runtime.core.resource.Container;
+import io.ray.streaming.runtime.master.JobMaster;
+import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext;
+import io.ray.streaming.runtime.master.coordinator.command.BaseWorkerCmd;
+import io.ray.streaming.runtime.master.coordinator.command.InterruptCheckpointRequest;
+import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest;
+import io.ray.streaming.runtime.rpc.async.AsyncRemoteCaller;
+import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo;
+import io.ray.streaming.runtime.util.ResourceUtil;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import org.apache.commons.collections.map.DefaultedMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Coordinator that consumes worker rollback requests and drives the rollback of the failed
+ * vertices, cascading to their upstream vertices when input data was lost.
+ */
+public class FailoverCoordinator extends BaseCoordinator {
+
+  private static final Logger LOG = LoggerFactory.getLogger(FailoverCoordinator.class);
+
+  private static final int ROLLBACK_RETRY_TIME_MS = 10 * 1000;
+  private final Object cmdLock = new Object();
+  private final AsyncRemoteCaller asyncRemoteCaller;
+  private long currentCascadingGroupId = 0;
+  private final Map isRollbacking =
+      DefaultedMap.decorate(new ConcurrentHashMap(), false);
+
+  public FailoverCoordinator(JobMaster jobMaster, boolean isRecover) {
+    this(jobMaster, new AsyncRemoteCaller(), isRecover);
+  }
+
+  public FailoverCoordinator(
+      JobMaster jobMaster, AsyncRemoteCaller asyncRemoteCaller,
+      boolean isRecover) {
+    super(jobMaster);
+
+    this.asyncRemoteCaller = asyncRemoteCaller;
+    // recover unfinished FO commands
+    JobMasterRuntimeContext runtimeContext = jobMaster.getRuntimeContext();
+    if (isRecover) {
+      runtimeContext.foCmds.addAll(runtimeContext.unfinishedFoCmds);
+    }
+    runtimeContext.unfinishedFoCmds.clear();
+  }
+
+  /**
+   * Coordinator loop: takes rollback requests from the failover command queue, records them as
+   * unfinished and handles them until the coordinator is closed.
+   */
+  @Override
+  public void run() {
+    while (!closed) {
+      try {
+        final BaseWorkerCmd command;
+        // see rollback() for lock reason
+        synchronized (cmdLock) {
+          command = jobMaster.getRuntimeContext().foCmds.poll(1, TimeUnit.SECONDS);
+        }
+        if (null == command) {
+          continue;
+        }
+        if (command instanceof WorkerRollbackRequest) {
+          jobMaster.getRuntimeContext().unfinishedFoCmds.add(command);
+          dealWithRollbackRequest((WorkerRollbackRequest) command);
+        }
+      } catch (Throwable e) {
+        LOG.error("Fo coordinator occur err.", e);
+      }
+    }
+    LOG.warn("Fo coordinator thread exit.");
+  }
+
+  /**
+   * Returns true if a command from the same actor is already waiting in the command queue.
+   */
+  private Boolean isDuplicateRequest(WorkerRollbackRequest request) {
+    try {
+      Object[] foCmdsArray = runtimeContext.foCmds.toArray();
+      for (Object cmd : foCmdsArray) {
+        if (request.fromActorId.equals(((BaseWorkerCmd) cmd).fromActorId)) {
+          return true;
+        }
+      }
+    } catch (Exception e) {
+      LOG.warn("Check request is duplicated failed.", e);
+    }
+    return false;
+  }
+
+  /**
+   * Enqueues a rollback request unless one from the same actor is already queued.
+   *
+   * @param request rollback request sent by a worker
+   * @return true if the request was enqueued or skipped as a duplicate, false if the command
+   *     queue rejected it
+   */
+  public Boolean requestJobWorkerRollback(WorkerRollbackRequest request) {
+    LOG.info("Request job worker rollback {}.", request);
+    boolean ret;
+    if (!isDuplicateRequest(request)) {
+      ret = runtimeContext.foCmds.offer(request);
+    } else {
+      LOG.warn("Skip duplicated worker rollback request, {}.", request.toString());
+      return true;
+    }
+    jobMaster.saveContext();
+    if (!ret) {
+      LOG.warn("Request job worker rollback failed, because command queue is full.");
+    }
+    return ret;
+  }
+
+  /**
+   * Handles one rollback request: skips vertices that are already rolling back, rolls back
+   * immediately for forced requests, otherwise first asks the worker whether it needs one.
+   */
+  private void dealWithRollbackRequest(WorkerRollbackRequest rollbackRequest) {
+    LOG.info("Start deal with rollback request {}.", rollbackRequest);
+
+    ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest);
+
+    // Reset pid for new-rollback actor.
+    if (null != rollbackRequest.getPid() &&
+        !rollbackRequest.getPid().equals(WorkerRollbackRequest.DEFAULT_PID)) {
+      exeVertex.setPid(rollbackRequest.getPid());
+    }
+
+    if (isRollbacking.get(exeVertex)) {
+      LOG.info("Vertex {} is rollbacking, skip rollback again.", exeVertex);
+      return;
+    }
+
+    String hostname = "";
+    Optional container = ResourceUtil.getContainerById(
+        jobMaster.getResourceManager().getRegisteredContainers(),
+        exeVertex.getContainerId()
+    );
+    if (container.isPresent()) {
+      hostname = container.get().getHostname();
+    }
+
+    if (rollbackRequest.isForcedRollback) {
+      interruptCheckpointAndRollback(rollbackRequest);
+    } else {
+      asyncRemoteCaller.checkIfNeedRollbackAsync(exeVertex.getWorkerActor(), res -> {
+        if (!res) {
+          LOG.info("Vertex {} doesn't need to rollback, skip it.", exeVertex);
+          return;
+        }
+        interruptCheckpointAndRollback(rollbackRequest);
+      }, throwable -> {
+        LOG.error("Exception when calling checkIfNeedRollbackAsync, maybe vertex is dead" +
+            ", ignore this request, vertex={}.", exeVertex, throwable);
+      });
+    }
+
+    LOG.info("Deal with rollback request {} success.", rollbackRequest);
+  }
+
+  /**
+   * Rolls the requesting vertex back to the last valid checkpoint and asks the checkpoint
+   * coordinator to interrupt the checkpoint in progress.
+   */
+  private void interruptCheckpointAndRollback(WorkerRollbackRequest rollbackRequest) {
+    // assign a new cascadingGroupId to requests that do not belong to a group yet
+    if (rollbackRequest.cascadingGroupId == null) {
+      rollbackRequest.cascadingGroupId = currentCascadingGroupId++;
+    }
+    // get last valid checkpoint id then call worker rollback
+    // Pass the id stored on the request, not the (already incremented) counter, so that
+    // every rollback of a cascading group carries the same id.
+    rollback(jobMaster.getRuntimeContext().getLastValidCheckpointId(), rollbackRequest,
+        rollbackRequest.cascadingGroupId);
+    // we interrupt current checkpoint for 2 considerations:
+    // 1. current checkpoint might be timeout, because barrier might be lost after failover. so we
+    // interrupt current checkpoint to avoid waiting.
+    // 2. when we want to rollback vertex to n, job finished checkpoint n+1 and cleared state
+    // of checkpoint n.
+    jobMaster.getRuntimeContext().cpCmds.offer(new InterruptCheckpointRequest());
+  }
+
+  /**
+   * call worker rollback, and deal with its reports. callback won't be finished until
+   * the entire DAG back to normal.
+   *
+   * @param checkpointId checkpointId to be rollback
+   * @param rollbackRequest worker rollback request
+   * @param cascadingGroupId all rollback of a cascading group should have same ID
+   */
+  private void rollback(
+      long checkpointId, WorkerRollbackRequest rollbackRequest,
+      long cascadingGroupId) {
+    ExecutionVertex exeVertex = getExeVertexFromRequest(rollbackRequest);
+    LOG.info("Call vertex {} to rollback, checkpoint id is {}, cascadingGroupId={}.",
+        exeVertex, checkpointId, cascadingGroupId);
+
+    isRollbacking.put(exeVertex, true);
+
+    asyncRemoteCaller.rollback(exeVertex.getWorkerActor(), checkpointId, result -> {
+      List newRollbackRequests = new ArrayList<>();
+      switch (result.getResultEnum()) {
+        case SUCCESS:
+          ChannelRecoverInfo recoverInfo = result.getResultObj();
+          LOG.info("Vertex {} rollback done, dataLostQueues={}, msg={}, cascadingGroupId={}.",
+              exeVertex, recoverInfo.getDataLostQueues(), result.getResultMsg(), cascadingGroupId);
+          // rollback upstream if vertex reports abnormal input queues
+          newRollbackRequests =
+              cascadeUpstreamActors(recoverInfo.getDataLostQueues(), exeVertex, cascadingGroupId);
+          break;
+        case SKIPPED:
+          LOG.info("Vertex skip rollback, result = {}, cascadingGroupId={}.", result,
+              cascadingGroupId);
+          break;
+        default:
+          LOG.error(
+              "Rollback vertex {} failed, result={}, cascadingGroupId={}," +
+                  " rollback this worker again after {} ms.",
+              exeVertex, result, cascadingGroupId, ROLLBACK_RETRY_TIME_MS);
+          Thread.sleep(ROLLBACK_RETRY_TIME_MS);
+          LOG.info("Add rollback request for {} again, cascadingGroupId={}.", exeVertex,
+              cascadingGroupId);
+          newRollbackRequests.add(
+              new WorkerRollbackRequest(exeVertex, "", "Rollback failed, try again.", false)
+          );
+          break;
+      }
+
+      // lock to avoid executing new rollback requests added.
+      // consider such a case: A->B->C, C cascade B, and B cascade A
+      // if B is rollback before B's rollback request is saved, and then JobMaster crashed,
+      // then A will never be rollback.
+      synchronized (cmdLock) {
+        jobMaster.getRuntimeContext().foCmds.addAll(newRollbackRequests);
+        // this rollback request is finished, remove it.
+        jobMaster.getRuntimeContext().unfinishedFoCmds.remove(rollbackRequest);
+        jobMaster.saveContext();
+      }
+      isRollbacking.put(exeVertex, false);
+    }, throwable -> {
+      LOG.error("Exception when calling vertex to rollback, vertex={}.", exeVertex, throwable);
+      isRollbacking.put(exeVertex, false);
+    });
+
+    LOG.info("Finish rollback vertex {}, checkpoint id is {}.", exeVertex, checkpointId);
+  }
+
+  /**
+   * Builds forced rollback requests for the upstream vertices of the lost input queues.
+   *
+   * @param dataLostQueues input queues of {@code fromVertex} that lost data
+   * @param fromVertex vertex that has just been rolled back
+   * @param cascadingGroupId id shared by all rollbacks of this cascading group
+   * @return rollback requests for the upstream vertices that are not rolling back already
+   */
+  private List cascadeUpstreamActors(
+      Set dataLostQueues, ExecutionVertex fromVertex, long cascadingGroupId) {
+    List cascadedRollbackRequest = new ArrayList<>();
+    // rollback upstream if vertex reports abnormal input queues
+    dataLostQueues.forEach(q -> {
+      BaseActorHandle upstreamActor =
+          graphManager.getExecutionGraph().getPeerActor(fromVertex.getWorkerActor(), q);
+      ExecutionVertex upstreamExeVertex = getExecutionVertex(upstreamActor);
+      // vertexes that has already cascaded by other vertex in the same level
+      // of graph should be ignored.
+      if (isRollbacking.get(upstreamExeVertex)) {
+        return;
+      }
+      LOG.info("Call upstream vertex {} of vertex {} to rollback, cascadingGroupId={}.",
+          upstreamExeVertex, fromVertex, cascadingGroupId);
+      String hostname = "";
+      Optional container = ResourceUtil.getContainerById(
+          jobMaster.getResourceManager().getRegisteredContainers(),
+          upstreamExeVertex.getContainerId()
+      );
+      if (container.isPresent()) {
+        hostname = container.get().getHostname();
+      }
+      // force upstream vertexes to rollback
+      WorkerRollbackRequest upstreamRequest = new WorkerRollbackRequest(
+          upstreamExeVertex, hostname, String.format("Cascading rollback from %s", fromVertex), true
+      );
+      upstreamRequest.cascadingGroupId = cascadingGroupId;
+      cascadedRollbackRequest.add(upstreamRequest);
+    });
+    return cascadedRollbackRequest;
+  }
+
+  /**
+   * Resolves the execution vertex of the actor that sent the request.
+   *
+   * @throws RuntimeException if the execution graph does not know the actor
+   */
+  private ExecutionVertex getExeVertexFromRequest(WorkerRollbackRequest rollbackRequest) {
+    ActorId actorId = rollbackRequest.fromActorId;
+    Optional rayActor = graphManager.getExecutionGraph().getActorById(actorId);
+    if (!rayActor.isPresent()) {
+      throw new RuntimeException("Can not find ray actor of ID " + actorId);
+    }
+    return getExecutionVertex(rollbackRequest.fromActorId);
+  }
+
+  private ExecutionVertex getExecutionVertex(BaseActorHandle actor) {
+    return graphManager.getExecutionGraph().getExecutionVertexByActorId(actor.getId());
+  }
+
+  private ExecutionVertex getExecutionVertex(ActorId actorId) {
+    return graphManager.getExecutionGraph().getExecutionVertexByActorId(actorId);
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java
new file mode 100644
index 000000000..2c6a9322d
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/BaseWorkerCmd.java
@@ -0,0 +1,17 @@
+package
io.ray.streaming.runtime.master.coordinator.command;
+
+import io.ray.api.id.ActorId;
+import java.io.Serializable;
+
+public abstract class BaseWorkerCmd implements Serializable {
+
+  public ActorId fromActorId;
+
+  public BaseWorkerCmd() {
+  }
+
+  protected BaseWorkerCmd(ActorId actorId) {
+    this.fromActorId = actorId;
+  }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java
new file mode 100644
index 000000000..29a46ab10
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/InterruptCheckpointRequest.java
@@ -0,0 +1,9 @@
+package io.ray.streaming.runtime.master.coordinator.command;
+
+/**
+ * Command that asks the checkpoint coordinator to interrupt the checkpoint in progress. It
+ * carries no payload and leaves {@code fromActorId} unset.
+ */
+public final class InterruptCheckpointRequest extends BaseWorkerCmd {
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java
new file mode 100644
index 000000000..7750ce1b0
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerCommitReport.java
@@ -0,0 +1,26 @@
+package io.ray.streaming.runtime.master.coordinator.command;
+
+import com.google.common.base.MoreObjects;
+import io.ray.api.id.ActorId;
+
+/**
+ * Report sent by a worker after it committed a checkpoint.
+ */
+public final class WorkerCommitReport extends BaseWorkerCmd {
+
+  /** Id of the checkpoint the worker committed. */
+  public final long commitCheckpointId;
+
+  public WorkerCommitReport(ActorId actorId, long commitCheckpointId) {
+    super(actorId);
+    this.commitCheckpointId = commitCheckpointId;
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("commitCheckpointId", commitCheckpointId)
+        .add("fromActorId", fromActorId)
+        .toString();
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java
new file mode 100644
index 000000000..e56518382
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/coordinator/command/WorkerRollbackRequest.java
@@ -0,0 +1,69 @@
+package io.ray.streaming.runtime.master.coordinator.command;
+
+import com.google.common.base.MoreObjects;
+import io.ray.api.id.ActorId;
+import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex;
+
+/**
+ * Request to roll back the worker identified by {@code fromActorId}.
+ */
+public final class WorkerRollbackRequest extends BaseWorkerCmd {
+
+  /** Value of {@link #getPid()} when no pid was supplied; a constant, hence final. */
+  public static final String DEFAULT_PID = "UNKNOWN_PID";
+  /** Id of the cascading rollback group; null until the failover coordinator assigns one. */
+  public Long cascadingGroupId = null;
+  /** When true the rollback is performed without first asking the worker whether it needs one. */
+  public boolean isForcedRollback = false;
+  private String exceptionMsg = "No detail message.";
+  private String hostname = "UNKNOWN_HOST";
+  private String pid = DEFAULT_PID;
+
+  public WorkerRollbackRequest(ActorId actorId) {
+    super(actorId);
+  }
+
+  public WorkerRollbackRequest(ActorId actorId, String msg) {
+    super(actorId);
+    exceptionMsg = msg;
+  }
+
+  public WorkerRollbackRequest(
+      ExecutionVertex executionVertex,
+      String hostname,
+      String msg,
+      boolean isForcedRollback) {
+
+    super(executionVertex.getWorkerActorId());
+
+    this.hostname = hostname;
+    this.pid = executionVertex.getPid();
+    this.exceptionMsg = msg;
+    this.isForcedRollback = isForcedRollback;
+  }
+
+  public WorkerRollbackRequest(ActorId actorId, String msg, String hostname, String pid) {
+    this(actorId, msg);
+    this.hostname = hostname;
+    this.pid = pid;
+  }
+
+  public String getRollbackExceptionMsg() {
+    return exceptionMsg;
+  }
+
+  public String getHostname() {
+    return hostname;
+  }
+
+  public String getPid() {
+    return pid;
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("fromActorId", fromActorId)
+        .toString();
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java index e76963a47..a977967ff 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/graphmanager/GraphManagerImpl.java @@ -1,14 +1,19 @@ package io.ray.streaming.runtime.master.graphmanager; +import io.ray.api.BaseActorHandle; import io.ray.streaming.jobgraph.JobGraph; import io.ray.streaming.jobgraph.JobVertex; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobEdge; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; -import io.ray.streaming.runtime.master.JobRuntimeContext; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -16,9 +21,9 @@ public class GraphManagerImpl implements GraphManager { private static final Logger LOG = LoggerFactory.getLogger(GraphManagerImpl.class); - protected final JobRuntimeContext runtimeContext; + protected final JobMasterRuntimeContext runtimeContext; - public GraphManagerImpl(JobRuntimeContext runtimeContext) { + public GraphManagerImpl(JobMasterRuntimeContext runtimeContext) { this.runtimeContext = runtimeContext; } @@ -48,6 +53,7 @@ public class GraphManagerImpl implements GraphManager { // create vertex Map
exeJobVertexMap = new LinkedHashMap<>(); + Map executionVertexMap = new HashMap<>(); long buildTime = executionGraph.getBuildTime(); for (JobVertex jobVertex : jobGraph.getJobVertices()) { int jobVertexId = jobVertex.getVertexId(); @@ -59,32 +65,47 @@ public class GraphManagerImpl implements GraphManager { buildTime)); } - // connect vertex + // for each job edge, connect all source exeVertices and target exeVertices jobGraph.getJobEdges().forEach(jobEdge -> { ExecutionJobVertex source = exeJobVertexMap.get(jobEdge.getSrcVertexId()); ExecutionJobVertex target = exeJobVertexMap.get(jobEdge.getTargetVertexId()); - ExecutionJobEdge executionJobEdge = - new ExecutionJobEdge(source, target, jobEdge); + ExecutionJobEdge executionJobEdge = new ExecutionJobEdge(source, target, jobEdge); source.getOutputEdges().add(executionJobEdge); target.getInputEdges().add(executionJobEdge); - source.getExecutionVertices().forEach(vertex -> { - target.getExecutionVertices().forEach(outputVertex -> { - ExecutionEdge executionEdge = new ExecutionEdge(vertex, outputVertex, executionJobEdge); - vertex.getOutputEdges().add(executionEdge); - outputVertex.getInputEdges().add(executionEdge); + source.getExecutionVertices().forEach(sourceExeVertex -> { + target.getExecutionVertices().forEach(targetExeVertex -> { + // pre-process some mappings + executionVertexMap.put(targetExeVertex.getExecutionVertexId(), targetExeVertex); + executionVertexMap.put(sourceExeVertex.getExecutionVertexId(), sourceExeVertex); + // build execution edge + ExecutionEdge executionEdge = + new ExecutionEdge(sourceExeVertex, targetExeVertex, executionJobEdge); + sourceExeVertex.getOutputEdges().add(executionEdge); + targetExeVertex.getInputEdges().add(executionEdge); }); }); }); // set execution job vertex into execution graph executionGraph.setExecutionJobVertexMap(exeJobVertexMap); + executionGraph.setExecutionVertexMap(executionVertexMap); return executionGraph; } + private void addActorToChannelGroupedActors( + Map> 
channelGroupedActors, + String channelId, + BaseActorHandle actor) { + + Set actorSet = + channelGroupedActors.computeIfAbsent(channelId, k -> new HashSet<>()); + actorSet.add(actor); + } + @Override public JobGraph getJobGraph() { return runtimeContext.getJobGraph(); diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java index 3b7b35ba6..2e59fed09 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerImpl.java @@ -11,7 +11,7 @@ import io.ray.streaming.runtime.config.types.ResourceAssignStrategyType; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.resource.Container; import io.ray.streaming.runtime.core.resource.Resources; -import io.ray.streaming.runtime.master.JobRuntimeContext; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategy; import io.ray.streaming.runtime.master.resourcemanager.strategy.ResourceAssignStrategyFactory; import io.ray.streaming.runtime.util.RayUtils; @@ -30,39 +30,33 @@ public class ResourceManagerImpl implements ResourceManager { //Container used tag private static final String CONTAINER_ENGAGED_KEY = "CONTAINER_ENGAGED_KEY"; - - /** - * Job runtime context. - */ - private JobRuntimeContext runtimeContext; - - /** - * Resource related configuration. - */ - private ResourceConfig resourceConfig; - - /** - * Slot assign strategy. - */ - private ResourceAssignStrategy resourceAssignStrategy; - /** * Resource description information. 
*/ private final Resources resources; - - /** - * Customized actor number for each container - */ - private int actorNumPerContainer; - /** * Timing resource updating thread */ private final ScheduledExecutorService resourceUpdater = new ScheduledThreadPoolExecutor(1, new ThreadFactoryBuilder().setNameFormat("resource-update-thread").build()); + /** + * Job runtime context. + */ + private JobMasterRuntimeContext runtimeContext; + /** + * Resource related configuration. + */ + private ResourceConfig resourceConfig; + /** + * Slot assign strategy. + */ + private ResourceAssignStrategy resourceAssignStrategy; + /** + * Customized actor number for each container + */ + private int actorNumPerContainer; - public ResourceManagerImpl(JobRuntimeContext runtimeContext) { + public ResourceManagerImpl(JobMasterRuntimeContext runtimeContext) { this.runtimeContext = runtimeContext; StreamingMasterConfig masterConfig = runtimeContext.getConf().masterConfig; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java index 6b9b3a690..238fdf6f7 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/JobSchedulerImpl.java @@ -23,20 +23,18 @@ import org.slf4j.LoggerFactory; public class JobSchedulerImpl implements JobScheduler { private static final Logger LOG = LoggerFactory.getLogger(JobSchedulerImpl.class); - - private StreamingConfig jobConf; - private final JobMaster jobMaster; private final ResourceManager resourceManager; private final GraphManager graphManager; private final WorkerLifecycleController workerLifecycleController; + private StreamingConfig jobConfig; public JobSchedulerImpl(JobMaster jobMaster) { this.jobMaster = jobMaster; 
this.graphManager = jobMaster.getGraphManager(); this.resourceManager = jobMaster.getResourceManager(); this.workerLifecycleController = new WorkerLifecycleController(); - this.jobConf = jobMaster.getRuntimeContext().getConf(); + this.jobConfig = jobMaster.getRuntimeContext().getConf(); LOG.info("Scheduler initiated."); } @@ -46,8 +44,13 @@ public class JobSchedulerImpl implements JobScheduler { LOG.info("Begin scheduling. Job: {}.", executionGraph.getJobName()); // Allocate resource then create workers + // Actor creation is in this step prepareResourceAndCreateWorker(executionGraph); + // now actor info is available in execution graph + // preprocess some handy mappings in execution graph + executionGraph.generateActorMappings(); + // init worker context and start to run initAndStart(executionGraph); @@ -87,7 +90,7 @@ public class JobSchedulerImpl implements JobScheduler { initMaster(); // start workers - startWorkers(executionGraph); + startWorkers(executionGraph, jobMaster.getRuntimeContext().lastCheckpointId); } /** @@ -122,7 +125,7 @@ public class JobSchedulerImpl implements JobScheduler { boolean result; try { result = workerLifecycleController.initWorkers(vertexToContextMap, - jobConf.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs()); + jobConfig.masterConfig.schedulerConfig.workerInitiationWaitTimeoutMs()); } catch (Exception e) { LOG.error("Failed to initiate workers.", e); return false; @@ -133,11 +136,12 @@ public class JobSchedulerImpl implements JobScheduler { /** * Start JobWorkers according to the physical plan. 
*/ - public boolean startWorkers(ExecutionGraph executionGraph) { + public boolean startWorkers(ExecutionGraph executionGraph, long checkpointId) { boolean result; try { result = workerLifecycleController.startWorkers( - executionGraph, jobConf.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs()); + executionGraph, checkpointId, + jobConfig.masterConfig.schedulerConfig.workerStartingWaitTimeoutMs()); } catch (Exception e) { LOG.error("Failed to start workers.", e); return false; @@ -194,7 +198,7 @@ public class JobSchedulerImpl implements JobScheduler { } private void initMaster() { - jobMaster.init(); + jobMaster.init(false); } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java index 876e9f924..bc8b462c7 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/master/scheduler/controller/WorkerLifecycleController.java @@ -9,6 +9,8 @@ import io.ray.api.id.ActorId; import io.ray.streaming.api.Language; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.python.GraphPbBuilder; import io.ray.streaming.runtime.rpc.RemoteCallWorker; import io.ray.streaming.runtime.worker.JobWorker; import io.ray.streaming.runtime.worker.context.JobWorkerContext; @@ -40,20 +42,23 @@ public class WorkerLifecycleController { * @return creation result */ private boolean createWorker(ExecutionVertex executionVertex) { - LOG.info("Start to create worker actor for vertex: {} with resource: {}.", - 
executionVertex.getExecutionVertexName(), executionVertex.getResource()); + LOG.info("Start to create worker actor for vertex: {} with resource: {}, workeConfig: {}.", + executionVertex.getExecutionVertexName(), executionVertex.getResource(), + executionVertex.getWorkerConfig()); Language language = executionVertex.getLanguage(); BaseActorHandle actor; if (Language.JAVA == language) { - actor = Ray.actor(JobWorker::new) + actor = Ray.actor(JobWorker::new, executionVertex) .setResources(executionVertex.getResource()) .setMaxRestarts(-1) .remote(); } else { + RemoteCall.ExecutionVertexContext.ExecutionVertex vertexPb + = new GraphPbBuilder().buildVertex(executionVertex); actor = Ray.actor( - PyActorClass.of("ray.streaming.runtime.worker", "JobWorker")) + PyActorClass.of("ray.streaming.runtime.worker", "JobWorker"), vertexPb.toByteArray()) .setResources(executionVertex.getResource()) .setMaxRestarts(-1) .remote(); @@ -111,20 +116,20 @@ public class WorkerLifecycleController { * @param timeout timeout for waiting, unit: ms * @return starting result */ - public boolean startWorkers(ExecutionGraph executionGraph, int timeout) { + public boolean startWorkers(ExecutionGraph executionGraph, long lastCheckpointId, int timeout) { LOG.info("Begin starting workers."); long startTime = System.currentTimeMillis(); - List> objectRefs = new ArrayList<>(); + List> objectRefs = new ArrayList<>(); // start source actors 1st executionGraph.getSourceActors() - .forEach(actor -> objectRefs.add(RemoteCallWorker.startWorker(actor))); + .forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId))); // then start non-source actors executionGraph.getNonSourceActors() - .forEach(actor -> objectRefs.add(RemoteCallWorker.startWorker(actor))); + .forEach(actor -> objectRefs.add(RemoteCallWorker.rollback(actor, lastCheckpointId))); - WaitResult result = Ray.wait(objectRefs, objectRefs.size(), timeout); + WaitResult result = Ray.wait(objectRefs, objectRefs.size(), timeout); 
if (result.getReady().size() != objectRefs.size()) { LOG.error("Starting workers timeout[{} ms].", timeout); return false; diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java new file mode 100644 index 000000000..5cdba0b0a --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/message/CallResult.java @@ -0,0 +1,122 @@ +package io.ray.streaming.runtime.message; + +import com.google.common.base.MoreObjects; +import java.io.Serializable; + +public class CallResult implements Serializable { + + protected T resultObj; + private boolean success; + private int resultCode; + private String resultMsg; + + public CallResult() { + } + + public CallResult(boolean success, int resultCode, String resultMsg, T resultObj) { + this.success = success; + this.resultCode = resultCode; + this.resultMsg = resultMsg; + this.resultObj = resultObj; + } + + public static CallResult success() { + return new CallResult<>(true, CallResultEnum.SUCCESS.code, CallResultEnum.SUCCESS.msg, null); + } + + public static CallResult success(T payload) { + return new CallResult<>(true, CallResultEnum.SUCCESS.code, CallResultEnum.SUCCESS.msg, payload); + } + + public static CallResult skipped(String msg) { + return new CallResult<>(true, CallResultEnum.SKIPPED.code, msg, null); + } + + public static CallResult fail() { + return new CallResult<>(false, CallResultEnum.FAILED.code, CallResultEnum.FAILED.msg, null); + } + + public static CallResult fail(T payload) { + return new CallResult<>(false, CallResultEnum.FAILED.code, CallResultEnum.FAILED.msg, payload); + } + + public static CallResult fail(String msg) { + return new CallResult<>(false, CallResultEnum.FAILED.code, msg, null); + } + + public static CallResult fail(CallResultEnum resultEnum, T payload) { + return new CallResult<>(false, resultEnum.code, 
resultEnum.msg, payload); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("resultObj", resultObj) + .add("success", success) + .add("resultCode", resultCode) + .add("resultMsg", resultMsg) + .toString(); + } + + public boolean isSuccess() { + return this.success; + } + + public void setSuccess(boolean success) { + this.success = success; + } + + public int getResultCode() { + return this.resultCode; + } + + public void setResultCode(int resultCode) { + this.resultCode = resultCode; + } + + public CallResultEnum getResultEnum() { + return CallResultEnum.getEnum(this.resultCode); + } + + public String getResultMsg() { + return this.resultMsg; + } + + public void setResultMsg(String resultMsg) { + this.resultMsg = resultMsg; + } + + public T getResultObj() { + return this.resultObj; + } + + public void setResultObj(T resultObj) { + this.resultObj = resultObj; + } + + public enum CallResultEnum implements Serializable { + /** + * call result enum + */ + SUCCESS(0, "SUCCESS"), + FAILED(1, "FAILED"), + SKIPPED(2, "SKIPPED"); + + public final int code; + public final String msg; + + CallResultEnum(int code, String msg) { + this.code = code; + this.msg = msg; + } + + public static CallResultEnum getEnum(int code) { + for (CallResultEnum value : CallResultEnum.values()) { + if (code == value.code) { + return value; + } + } + return FAILED; + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java index 408397ebb..f7bbc6278 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/python/GraphPbBuilder.java @@ -65,7 +65,7 @@ public class GraphPbBuilder { return builder.build(); } - private 
RemoteCall.ExecutionVertexContext.ExecutionVertex buildVertex( + public RemoteCall.ExecutionVertexContext.ExecutionVertex buildVertex( ExecutionVertex executionVertex) { // build vertex infos RemoteCall.ExecutionVertexContext.ExecutionVertex.Builder executionVertexBuilder = @@ -79,9 +79,11 @@ public class GraphPbBuilder { ByteString.copyFrom( serializeOperator(executionVertex.getStreamOperator()))); executionVertexBuilder.setChained(isPythonChainedOperator(executionVertex.getStreamOperator())); - executionVertexBuilder.setWorkerActor( - ByteString.copyFrom( - ((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes())); + if (executionVertex.getWorkerActor() != null) { + executionVertexBuilder.setWorkerActor( + ByteString.copyFrom( + ((NativeActorHandle) (executionVertex.getWorkerActor())).toBytes())); + } executionVertexBuilder.setContainerId(executionVertex.getContainerId().toString()); executionVertexBuilder.setBuildTime(executionVertex.getBuildTime()); executionVertexBuilder.setLanguage( diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java new file mode 100644 index 000000000..c0bbcc2c2 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/PbResultParser.java @@ -0,0 +1,54 @@ +package io.ray.streaming.runtime.rpc; + +import com.google.protobuf.InvalidProtocolBufferException; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.message.CallResult; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class PbResultParser { + + private static final Logger LOG = LoggerFactory.getLogger(PbResultParser.class); + + public static Boolean parseBoolResult(byte[] result) { + if (null == result) { + 
LOG.warn("Result is null."); + return false; + } + + RemoteCall.BoolResult boolResult; + try { + boolResult = RemoteCall.BoolResult.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + LOG.error("Parse boolean result has exception.", e); + return false; + } + + return boolResult.getBoolRes(); + } + + public static CallResult parseRollbackResult(byte[] bytes) { + RemoteCall.CallResult callResultPb; + try { + callResultPb = RemoteCall.CallResult.parseFrom(bytes); + } catch (InvalidProtocolBufferException e) { + LOG.error("Rollback parse result has exception.", e); + return CallResult.fail(); + } + + CallResult callResult = new CallResult<>(); + callResult.setSuccess(callResultPb.getSuccess()); + callResult.setResultCode(callResultPb.getResultCode()); + callResult.setResultMsg(callResultPb.getResultMsg()); + RemoteCall.QueueRecoverInfo recoverInfo = callResultPb.getResultObj(); + Map creationStatusMap = new HashMap<>(); + recoverInfo.getCreationStatusMap().forEach((k, v) -> { + creationStatusMap.put(k, ChannelRecoverInfo.ChannelCreationStatus.fromInt(v.getNumber())); + }); + callResult.setResultObj(new ChannelRecoverInfo(creationStatusMap)); + return callResult; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java new file mode 100644 index 000000000..fe25002bf --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallMaster.java @@ -0,0 +1,46 @@ +package io.ray.streaming.runtime.rpc; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; +import 
io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; + +public class RemoteCallMaster { + + public static ObjectRef reportJobWorkerCommitAsync( + ActorHandle actor, + WorkerCommitReport commitReport) { + RemoteCall.WorkerCommitReport commit = RemoteCall.WorkerCommitReport.newBuilder() + .setCommitCheckpointId(commitReport.commitCheckpointId) + .build(); + Any detail = Any.pack(commit); + RemoteCall.BaseWorkerCmd cmd = RemoteCall.BaseWorkerCmd.newBuilder() + .setActorId(ByteString.copyFrom(commitReport.fromActorId.getBytes())) + .setTimestamp(System.currentTimeMillis()) + .setDetail(detail).build(); + + return actor.task(JobMaster::reportJobWorkerCommit, cmd.toByteArray()).remote(); + } + + public static Boolean requestJobWorkerRollback( + ActorHandle actor, + WorkerRollbackRequest rollbackRequest) { + RemoteCall.WorkerRollbackRequest request = RemoteCall.WorkerRollbackRequest.newBuilder() + .setExceptionMsg(rollbackRequest.getRollbackExceptionMsg()) + .setWorkerHostname(rollbackRequest.getHostname()) + .setWorkerPid(rollbackRequest.getPid()).build(); + Any detail = Any.pack(request); + RemoteCall.BaseWorkerCmd cmd = RemoteCall.BaseWorkerCmd.newBuilder() + .setActorId(ByteString.copyFrom(rollbackRequest.fromActorId.getBytes())) + .setTimestamp(System.currentTimeMillis()) + .setDetail(detail).build(); + ObjectRef ret = actor.task( + JobMaster::requestJobWorkerRollback, cmd.toByteArray()).remote(); + byte[] res = ret.get(); + return PbResultParser.parseBoolResult(res); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java index a12dfaea4..d9b373370 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/RemoteCallWorker.java @@ -4,10 +4,15 @@ import 
io.ray.api.ActorHandle; import io.ray.api.BaseActorHandle; import io.ray.api.ObjectRef; import io.ray.api.PyActorHandle; +import io.ray.api.Ray; import io.ray.api.function.PyActorMethod; +import io.ray.api.function.RayFunc3; +import io.ray.streaming.runtime.generated.RemoteCall; import io.ray.streaming.runtime.master.JobMaster; import io.ray.streaming.runtime.worker.JobWorker; import io.ray.streaming.runtime.worker.context.JobWorkerContext; +import java.util.ArrayList; +import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,19 +51,26 @@ public class RemoteCallWorker { * Call JobWorker actor to start. * * @param actor target JobWorker actor + * @param checkpointId checkpoint ID to be rollback * @return start result */ - public static ObjectRef startWorker(BaseActorHandle actor) { + public static ObjectRef rollback(BaseActorHandle actor, final Long checkpointId) { LOG.info("Call worker to start, actor: {}.", actor.getId()); - ObjectRef result = null; + ObjectRef result; // python if (actor instanceof PyActorHandle) { + RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder() + .setCheckpointId(checkpointId) + .build(); result = ((PyActorHandle) actor) - .task(PyActorMethod.of("start", Boolean.class)).remote(); + .task(PyActorMethod.of("rollback"), + checkpointIdPb.toByteArray() + ).remote(); } else { // java - result = ((ActorHandle) actor).task(JobWorker::start).remote(); + result = ((ActorHandle) actor) + .task(JobWorker::rollback, checkpointId, System.currentTimeMillis()).remote(); } LOG.info("Finished calling worker to start."); @@ -82,4 +94,92 @@ public class RemoteCallWorker { return result; } + public static ObjectRef triggerCheckpoint(BaseActorHandle actor, Long barrierId) { + // python + if (actor instanceof PyActorHandle) { + RemoteCall.Barrier barrierPb = RemoteCall.Barrier.newBuilder().setId(barrierId).build(); + return ((PyActorHandle) actor).task( + PyActorMethod.of("commit"), 
barrierPb.toByteArray()).remote(); + } else { + // java + return ((ActorHandle) actor).task(JobWorker::triggerCheckpoint, barrierId) + .remote(); + } + } + + public static void clearExpiredCheckpointParallel( + List actors, Long stateCheckpointId, + Long queueCheckpointId) { + if (LOG.isInfoEnabled()) { + LOG.info("Call worker clearExpiredCheckpoint, state checkpoint id is {}," + + " queue checkpoint id is {}.", stateCheckpointId, queueCheckpointId); + } + + List result = + checkpointCompleteCommonCallTwoWay(actors, stateCheckpointId, queueCheckpointId, + "clear_expired_cp", JobWorker::clearExpiredCheckpoint); + + if (LOG.isInfoEnabled()) { + result.forEach( + obj -> LOG.info("Finish call worker clearExpiredCheckpointParallel, ret is {}.", obj)); + } + } + + public static void notifyCheckpointTimeoutParallel( + List actors, + Long checkpointId) { + LOG.info("Call worker notifyCheckpointTimeoutParallel, checkpoint id is {}", checkpointId); + + actors.forEach(actor -> { + if (actor instanceof PyActorHandle) { + RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder() + .setCheckpointId(checkpointId) + .build(); + ((PyActorHandle) actor).task(PyActorMethod.of("notify_checkpoint_timeout"), + checkpointIdPb.toByteArray()).remote(); + } else { + ((ActorHandle) actor).task(JobWorker::notifyCheckpointTimeout, checkpointId) + .remote(); + } + }); + + LOG.info("Finish call worker notifyCheckpointTimeoutParallel."); + } + + private static List checkpointCompleteCommonCallTwoWay( + List actors, Long stateCheckpointId, Long queueCheckpointId, + String pyFuncName, RayFunc3 rayFunc) { + List> waitFor = + checkpointCompleteCommonCall(actors, stateCheckpointId, queueCheckpointId, + pyFuncName, rayFunc); + return Ray.get(waitFor); + } + + private static List> checkpointCompleteCommonCall( + List actors, + Long stateCheckpointId, Long queueCheckpointId, + String pyFuncName, + RayFunc3 rayFunc) { + List> waitFor = new ArrayList<>(); + actors.forEach(actor -> { + 
// python + if (actor instanceof PyActorHandle) { + RemoteCall.CheckpointId stateCheckpointIdPb = RemoteCall.CheckpointId.newBuilder() + .setCheckpointId(stateCheckpointId) + .build(); + + RemoteCall.CheckpointId queueCheckpointIdPb = RemoteCall.CheckpointId.newBuilder() + .setCheckpointId(queueCheckpointId) + .build(); + waitFor.add(((PyActorHandle) actor).task(PyActorMethod.of(pyFuncName), + stateCheckpointIdPb.toByteArray(), queueCheckpointIdPb.toByteArray()).remote()); + } else { + // java + waitFor.add(((ActorHandle) actor).task(rayFunc, stateCheckpointId, queueCheckpointId) + .remote()); + } + }); + return waitFor; + } + } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java new file mode 100644 index 000000000..db7937159 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/AsyncRemoteCaller.java @@ -0,0 +1,131 @@ +package io.ray.streaming.runtime.rpc.async; + +import io.ray.api.ActorHandle; +import io.ray.api.BaseActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.PyActorHandle; +import io.ray.api.function.PyActorMethod; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.message.CallResult; +import io.ray.streaming.runtime.rpc.PbResultParser; +import io.ray.streaming.runtime.rpc.async.RemoteCallPool.Callback; +import io.ray.streaming.runtime.rpc.async.RemoteCallPool.ExceptionHandler; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.worker.JobWorker; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings("unchecked") +public class AsyncRemoteCaller { + + private static final Logger LOG = 
LoggerFactory.getLogger(AsyncRemoteCaller.class); + private RemoteCallPool remoteCallPool = new RemoteCallPool(); + + /** + * Call JobWorker::checkIfNeedRollback async + * + * @param actor JobWorker actor + * @param callback callback function on success + * @param onException callback function on exception + */ + public void checkIfNeedRollbackAsync( + BaseActorHandle actor, Callback callback, + ExceptionHandler onException) { + if (actor instanceof PyActorHandle) { + // python + remoteCallPool.bindCallback( + ((PyActorHandle) actor).task(PyActorMethod.of("check_if_need_rollback")).remote(), + (obj) -> { + byte[] res = (byte[]) obj; + callback.handle(PbResultParser.parseBoolResult(res)); + }, onException); + } else { + // java + remoteCallPool.bindCallback( + ((ActorHandle) actor).task(JobWorker::checkIfNeedRollback, + System.currentTimeMillis()).remote(), callback, onException); + } + } + + /** + * Call JobWorker::rollback async + * + * @param actor JobWorker actor + * @param callback callback function on success + * @param onException callback function on exception + */ + public void rollback( + BaseActorHandle actor, + final Long checkpointId, + Callback> callback, + ExceptionHandler onException) { + // python + if (actor instanceof PyActorHandle) { + RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder() + .setCheckpointId(checkpointId) + .build(); + ObjectRef call = ((PyActorHandle) actor).task(PyActorMethod.of("rollback"), + checkpointIdPb.toByteArray()).remote(); + remoteCallPool.bindCallback(call, obj -> + callback.handle(PbResultParser.parseRollbackResult((byte[]) obj)), onException); + } else { + // java + ObjectRef call = ((ActorHandle) actor).task( + JobWorker::rollback, checkpointId, System.currentTimeMillis()).remote(); + remoteCallPool.bindCallback(call, obj -> { + CallResult res = (CallResult) obj; + callback.handle(res); + }, onException); + } + } + + /** + * Call JobWorker::rollback async in batch + * + * @param actors 
JobWorker actor list + * @param callback callback function on success + * @param onException callback function on exception + */ + public void batchRollback( + List actors, final Long checkpointId, + Collection abnormalQueues, + Callback>> callback, + ExceptionHandler onException) { + List> rayCallList = new ArrayList<>(); + Map isPyActor = new HashMap<>(); + for (int i = 0; i < actors.size(); ++i) { + BaseActorHandle actor = actors.get(i); + ObjectRef call; + if (actor instanceof PyActorHandle) { + isPyActor.put(i, true); + RemoteCall.CheckpointId checkpointIdPb = RemoteCall.CheckpointId.newBuilder() + .setCheckpointId(checkpointId) + .build(); + call = ((PyActorHandle) actor).task(PyActorMethod.of("rollback"), + checkpointIdPb.toByteArray()).remote(); + } else { + // java + call = ((ActorHandle) actor).task(JobWorker::rollback, checkpointId, + System.currentTimeMillis()).remote(); + } + rayCallList.add(call); + } + remoteCallPool.bindCallback(rayCallList, objList -> { + List> results = new ArrayList<>(); + for (int i = 0; i < objList.size(); ++i) { + Object obj = objList.get(i); + if (isPyActor.getOrDefault(i, false)) { + results.add(PbResultParser.parseRollbackResult((byte[]) obj)); + } else { + results.add((CallResult) obj); + } + } + callback.handle(results); + }, onException); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java new file mode 100644 index 000000000..52e9e5651 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/rpc/async/RemoteCallPool.java @@ -0,0 +1,189 @@ +package io.ray.streaming.runtime.rpc.async; + +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.api.WaitResult; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import 
java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class RemoteCallPool implements Runnable { + + private static final Logger LOG = LoggerFactory.getLogger(RemoteCallPool.class); + private static final int WAIT_TIME_MS = 5; + private static final long WARNING_PERIOD = 10000; + private final List pendingObjectBundles = new LinkedList<>(); + private Map> singletonHandlerMap = new ConcurrentHashMap<>(); + private Map>> bundleHandlerMap = + new ConcurrentHashMap<>(); + private Map> bundleExceptionHandlerMap = + new ConcurrentHashMap<>(); + private ThreadPoolExecutor callBackPool = new ThreadPoolExecutor( + 2, Runtime.getRuntime().availableProcessors(), + 1, TimeUnit.MINUTES, new LinkedBlockingQueue<>(), + new CallbackThreadFactory()); + private volatile boolean stop = false; + + public RemoteCallPool() { + Thread t = new Thread(Ray.wrapRunnable(this), "remote-pool-loop"); + t.setUncaughtExceptionHandler((thread, throwable) -> + LOG.error("Error in remote call pool thread.", throwable) + ); + t.start(); + } + + @SuppressWarnings("unchecked") + public void bindCallback( + ObjectRef obj, Callback callback, + ExceptionHandler onException) { + List objectRefList = Collections.singletonList(obj); + RemoteCallBundle bundle = new RemoteCallBundle(objectRefList, + true); + singletonHandlerMap.put(bundle, (Callback) callback); + bundleExceptionHandlerMap.put(bundle, onException); + synchronized (pendingObjectBundles) { + pendingObjectBundles.add(bundle); + } + } + + public void bindCallback( + List> objectBundle, Callback> callback, + ExceptionHandler onException) { + RemoteCallBundle bundle = new RemoteCallBundle(objectBundle, false); + 
bundleHandlerMap.put(bundle, callback); + bundleExceptionHandlerMap.put(bundle, onException); + synchronized (pendingObjectBundles) { + pendingObjectBundles.add(bundle); + } + } + + public void stop() { + stop = true; + } + + public void run() { + while (!stop) { + try { + if (pendingObjectBundles.isEmpty()) { + Thread.sleep(WAIT_TIME_MS); + continue; + } + synchronized (pendingObjectBundles) { + Iterator itr = pendingObjectBundles.iterator(); + while (itr.hasNext()) { + RemoteCallBundle bundle = itr.next(); + WaitResult waitResult = + Ray.wait(bundle.objects, bundle.objects.size(), WAIT_TIME_MS); + List> readyObjs = waitResult.getReady(); + if (readyObjs.size() != bundle.objects.size()) { + long now = System.currentTimeMillis(); + long waitingTime = now - bundle.createTime; + if (waitingTime > WARNING_PERIOD && now - bundle.lastWarnTs > WARNING_PERIOD) { + bundle.lastWarnTs = now; + LOG.warn("Bundle has being waiting for {} ms, bundle = {}.", waitingTime, bundle); + } + continue; + } + + ExceptionHandler exceptionHandler = bundleExceptionHandlerMap.get(bundle); + if (bundle.isSingletonBundle) { + callBackPool.execute(Ray.wrapRunnable(() -> { + try { + singletonHandlerMap.get(bundle).handle(readyObjs.get(0).get()); + singletonHandlerMap.remove(bundle); + } catch (Throwable th) { + LOG.error("Error when get object, objectId = {}.", readyObjs.get(0).toString(), + th); + if (exceptionHandler != null) { + exceptionHandler.handle(th); + } + } + })); + } else { + List results = + readyObjs.stream().map(ObjectRef::get).collect(Collectors.toList()); + List resultIds = + readyObjs.stream().map(ObjectRef::toString).collect(Collectors.toList()); + callBackPool.execute(Ray.wrapRunnable(() -> { + try { + bundleHandlerMap.get(bundle).handle(results); + bundleHandlerMap.remove(bundle); + } catch (Throwable th) { + LOG.error("Error when get object, objectIds = {}.", resultIds, th); + if (exceptionHandler != null) { + exceptionHandler.handle(th); + } + } + })); + } + itr.remove(); 
+ } + } + + } catch (Exception e) { + LOG.error("Exception in wait loop.", e); + } + } + LOG.info("Wait loop finished."); + } + + @FunctionalInterface + public interface ExceptionHandler { + + void handle(T object); + } + + @FunctionalInterface + public interface Callback { + + void handle(T object) throws Throwable; + } + + private static class RemoteCallBundle { + + List> objects; + boolean isSingletonBundle; + long lastWarnTs = System.currentTimeMillis(); + long createTime = System.currentTimeMillis(); + + RemoteCallBundle(List> objects, boolean isSingletonBundle) { + this.objects = objects; + this.isSingletonBundle = isSingletonBundle; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("["); + objects.forEach(rayObj -> sb.append(rayObj.toString()).append(",")); + sb.append("]"); + return sb.toString(); + } + } + + static class CallbackThreadFactory implements ThreadFactory { + + private AtomicInteger cnt = new AtomicInteger(0); + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setUncaughtExceptionHandler((thread, throwable) -> LOG.error("Callback err.", throwable)); + t.setName("callback-thread-" + cnt.getAndIncrement()); + return t; + } + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataMessage.java deleted file mode 100644 index 6c8f08d8e..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataMessage.java +++ /dev/null @@ -1,55 +0,0 @@ -package io.ray.streaming.runtime.transfer; - -import java.nio.ByteBuffer; - -/** - * DataMessage represents data between upstream and downstream operator - */ -public class DataMessage implements Message { - - private final ByteBuffer body; - private final long msgId; - private final long timestamp; - private final String channelId; - - public 
DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) { - this.body = body; - this.timestamp = timestamp; - this.msgId = msgId; - this.channelId = channelId; - } - - @Override - public ByteBuffer body() { - return body; - } - - @Override - public long timestamp() { - return timestamp; - } - - /** - * @return message id - */ - public long msgId() { - return msgId; - } - - /** - * @return string id of channel where data is coming from - */ - public String channelId() { - return channelId; - } - - @Override - public String toString() { - return "DataMessage{" + - "body=" + body + - ", msgId=" + msgId + - ", timestamp=" + timestamp + - ", channelId='" + channelId + '\'' + - '}'; - } -} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java index 3cdf15a07..f10571796 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataReader.java @@ -4,11 +4,22 @@ import com.google.common.base.Preconditions; import io.ray.api.BaseActorHandle; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.config.types.TransferChannelType; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo.ChannelCreationStatus; +import io.ray.streaming.runtime.transfer.channel.ChannelUtils; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import io.ray.streaming.runtime.transfer.message.BarrierMessage; +import io.ray.streaming.runtime.transfer.message.ChannelMessage; +import io.ray.streaming.runtime.transfer.message.DataMessage; import io.ray.streaming.runtime.util.Platform; import 
java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Queue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -22,7 +33,20 @@ public class DataReader { private static final Logger LOG = LoggerFactory.getLogger(DataReader.class); private long nativeReaderPtr; - private Queue buf = new LinkedList<>(); + // params set by getBundleNative: bundle data address + size + private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24); + // We use direct buffer to reduce gc overhead and memory copy. + private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0); + private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH); + + private final Map queueCreationStatusMap = new HashMap<>(); + private Queue buf = new LinkedList<>(); + + { + getBundleParams.order(ByteOrder.nativeOrder()); + bundleData.order(ByteOrder.nativeOrder()); + bundleMeta.order(ByteOrder.nativeOrder()); + } /** * @param inputChannels input channels ids @@ -32,6 +56,7 @@ public class DataReader { public DataReader( List inputChannels, List fromActors, + Map checkpoints, StreamingWorkerConfig workerConfig) { Preconditions.checkArgument(inputChannels.size() > 0); Preconditions.checkArgument(inputChannels.size() == fromActors.size()); @@ -39,11 +64,16 @@ public class DataReader { new ChannelCreationParametersBuilder().buildInputQueueParameters(inputChannels, fromActors); byte[][] inputChannelsBytes = inputChannels.stream() .map(ChannelId::idStrToBytes).toArray(byte[][]::new); - long[] seqIds = new long[inputChannels.size()]; + + // get sequence ID and message ID from OffsetInfo long[] msgIds = new long[inputChannels.size()]; for (int i = 0; i < inputChannels.size(); i++) { - seqIds[i] = 0; - msgIds[i] = 0; + String channelId = inputChannels.get(i); + if (!checkpoints.containsKey(channelId)) { + msgIds[i] = 0; + 
continue; + } + msgIds[i] = checkpoints.get(inputChannels.get(i)).getStreamingMsgId(); } long timerInterval = workerConfig.transferConfig.readerTimerIntervalMs(); TransferChannelType channelType = workerConfig.transferConfig.channelType(); @@ -51,33 +81,34 @@ public class DataReader { if (TransferChannelType.MEMORY_CHANNEL == channelType) { isMock = true; } - boolean isRecreate = workerConfig.transferConfig.readerIsRecreate(); + // create native reader + List creationStatus = new ArrayList<>(); this.nativeReaderPtr = createDataReaderNative( initialParameters, inputChannelsBytes, - seqIds, msgIds, timerInterval, - isRecreate, + creationStatus, ChannelUtils.toNativeConf(workerConfig), isMock ); - LOG.info("Create DataReader succeed for worker: {}.", - workerConfig.workerInternalConfig.workerName()); + for (int i = 0; i < inputChannels.size(); ++i) { + queueCreationStatusMap + .put(inputChannels.get(i), ChannelCreationStatus.fromInt(creationStatus.get(i))); + } + LOG.info("Create DataReader succeed for worker: {}, creation status={}.", + workerConfig.workerInternalConfig.workerName(), queueCreationStatusMap); } - // params set by getBundleNative: bundle data address + size - private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24); - // We use direct buffer to reduce gc overhead and memory copy. - private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0); - private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH); - - { - getBundleParams.order(ByteOrder.nativeOrder()); - bundleData.order(ByteOrder.nativeOrder()); - bundleMeta.order(ByteOrder.nativeOrder()); - } + private static native long createDataReaderNative( + ChannelCreationParametersBuilder initialParameters, + byte[][] inputChannels, + long[] msgIds, + long timerInterval, + List creationStatus, + byte[] configBytes, + boolean isMock); /** * Read message from input channels, if timeout, return null. 
@@ -85,26 +116,21 @@ public class DataReader { * @param timeoutMillis timeout * @return message or null */ - public DataMessage read(long timeoutMillis) { + public ChannelMessage read(long timeoutMillis) { if (buf.isEmpty()) { getBundle(timeoutMillis); // if bundle not empty. empty message still has data size + seqId + msgId if (bundleData.position() < bundleData.limit()) { BundleMeta bundleMeta = new BundleMeta(this.bundleMeta); + String channelID = bundleMeta.getChannelID(); + long timestamp = bundleMeta.getBundleTs(); // barrier if (bundleMeta.getBundleType() == DataBundleType.BARRIER) { - throw new UnsupportedOperationException( - "Unsupported bundle type " + bundleMeta.getBundleType()); + buf.offer(getBarrier(bundleData, channelID, timestamp)); } else if (bundleMeta.getBundleType() == DataBundleType.BUNDLE) { - String channelID = bundleMeta.getChannelID(); - long timestamp = bundleMeta.getBundleTs(); for (int i = 0; i < bundleMeta.getMessageListSize(); i++) { buf.offer(getDataMessage(bundleData, channelID, timestamp)); } - } else if (bundleMeta.getBundleType() == DataBundleType.EMPTY) { - long messageId = bundleMeta.getLastMessageId(); - buf.offer(new DataMessage(null, bundleMeta.getBundleTs(), - messageId, bundleMeta.getChannelID())); } } } @@ -114,6 +140,31 @@ public class DataReader { return buf.poll(); } + public ChannelRecoverInfo getQueueRecoverInfo() { + return new ChannelRecoverInfo(queueCreationStatusMap); + } + + private String getQueueIdString(ByteBuffer buffer) { + byte[] bytes = new byte[ChannelId.ID_LENGTH]; + buffer.get(bytes); + return ChannelId.idBytesToStr(bytes); + } + + private BarrierMessage getBarrier(ByteBuffer bundleData, String channelID, long timestamp) { + ByteBuffer offsetsInfoBytes = ByteBuffer.wrap(getOffsetsInfoNative(nativeReaderPtr)); + offsetsInfoBytes.order(ByteOrder.nativeOrder()); + BarrierOffsetInfo offsetInfo = new BarrierOffsetInfo(offsetsInfoBytes); + DataMessage message = getDataMessage(bundleData, channelID, 
timestamp); + BarrierItem barrierItem = new BarrierItem(message, offsetInfo); + return new BarrierMessage( + message.getMsgId(), + message.getTimestamp(), + message.getChannelId(), + barrierItem.getData(), + barrierItem.getGlobalBarrierId(), + barrierItem.getBarrierOffsetInfo().getQueueOffsetInfo()); + } + private DataMessage getDataMessage(ByteBuffer bundleData, String channelID, long timestamp) { int dataSize = bundleData.getInt(); // msgId @@ -161,22 +212,14 @@ public class DataReader { LOG.info("Finish closing DataReader."); } - private static native long createDataReaderNative( - ChannelCreationParametersBuilder initialParameters, - byte[][] inputChannels, - long[] seqIds, - long[] msgIds, - long timerInterval, - boolean isRecreate, - byte[] configBytes, - boolean isMock); - private native void getBundleNative( long nativeReaderPtr, long timeoutMillis, long params, long metaAddress); + private native byte[] getOffsetsInfoNative(long nativeQueueConsumerPtr); + private native void stopReaderNative(long nativeReaderPtr); private native void closeReaderNative(long nativeReaderPtr); @@ -193,7 +236,16 @@ public class DataReader { } } - static class BundleMeta { + public enum BarrierType { + GLOBAL_BARRIER(0); + private int code; + + BarrierType(int code) { + this.code = code; + } + } + + class BundleMeta { // kMessageBundleHeaderSize + kUniqueIDSize: // magicNum(4b) + bundleTs(8b) + lastMessageId(8b) + messageListSize(4b) @@ -226,13 +278,7 @@ public class DataReader { } // rawBundleSize rawBundleSize = buffer.getInt(); - channelID = getQidString(buffer); - } - - private String getQidString(ByteBuffer buffer) { - byte[] bytes = new byte[ChannelId.ID_LENGTH]; - buffer.get(bytes); - return ChannelId.idBytesToStr(bytes); + channelID = getQueueIdString(buffer); } public int getMagicNum() { @@ -264,4 +310,73 @@ public class DataReader { } } + class BarrierOffsetInfo { + + private int queueSize; + private Map queueOffsetInfo; + + public BarrierOffsetInfo(ByteBuffer buffer) 
{ + // deserialization offset + queueSize = buffer.getInt(); + queueOffsetInfo = new HashMap<>(queueSize); + for (int i = 0; i < queueSize; ++i) { + String qid = getQueueIdString(buffer); + long streamingMsgId = buffer.getLong(); + queueOffsetInfo.put(qid, new OffsetInfo(streamingMsgId)); + } + } + + public int getQueueSize() { + return queueSize; + } + + public Map getQueueOffsetInfo() { + return queueOffsetInfo; + } + } + + class BarrierItem { + + BarrierOffsetInfo barrierOffsetInfo; + private long msgId; + private BarrierType barrierType; + private long globalBarrierId; + private ByteBuffer data; + + public BarrierItem(DataMessage message, BarrierOffsetInfo barrierOffsetInfo) { + this.barrierOffsetInfo = barrierOffsetInfo; + msgId = message.getMsgId(); + ByteBuffer buffer = message.body(); + // c++ use native order, so use native order here. + buffer.order(ByteOrder.nativeOrder()); + int barrierTypeInt = buffer.getInt(); + globalBarrierId = buffer.getLong(); + // dataSize includes: barrier type(32 bit), globalBarrierId, data + data = buffer.slice(); + data.order(ByteOrder.nativeOrder()); + buffer.position(buffer.limit()); + barrierType = BarrierType.GLOBAL_BARRIER; + } + + public long getBarrierMsgId() { + return msgId; + } + + public BarrierType getBarrierType() { + return barrierType; + } + + public long getGlobalBarrierId() { + return globalBarrierId; + } + + public ByteBuffer getData() { + return data; + } + + public BarrierOffsetInfo getBarrierOffsetInfo() { + return barrierOffsetInfo; + } + } + } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java index a8cebabb0..55729c7fb 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/DataWriter.java @@ -4,10 +4,15 @@ 
import com.google.common.base.Preconditions; import io.ray.api.BaseActorHandle; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.config.types.TransferChannelType; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.transfer.channel.ChannelUtils; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; import io.ray.streaming.runtime.util.Platform; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -22,6 +27,7 @@ public class DataWriter { private long nativeWriterPtr; private ByteBuffer buffer = ByteBuffer.allocateDirect(0); private long bufferAddress; + private List outputChannels; { ensureBuffer(0); @@ -31,21 +37,33 @@ public class DataWriter { * @param outputChannels output channels ids * @param toActors downstream output actors * @param workerConfig configuration + * @param checkpoints offset of each channels */ public DataWriter( List outputChannels, List toActors, + Map checkpoints, StreamingWorkerConfig workerConfig) { Preconditions.checkArgument(!outputChannels.isEmpty()); Preconditions.checkArgument(outputChannels.size() == toActors.size()); + this.outputChannels = outputChannels; + ChannelCreationParametersBuilder initialParameters = new ChannelCreationParametersBuilder().buildOutputQueueParameters(outputChannels, toActors); + byte[][] outputChannelsBytes = outputChannels.stream() .map(ChannelId::idStrToBytes).toArray(byte[][]::new); long channelSize = workerConfig.transferConfig.channelSize(); + + // load message id from checkpoints long[] msgIds = new long[outputChannels.size()]; for (int i = 0; i < outputChannels.size(); i++) { - msgIds[i] = 0; + String channelId = outputChannels.get(i); + if (!checkpoints.containsKey(channelId)) { + msgIds[i] = 0; + continue; + } + msgIds[i] = 
checkpoints.get(channelId).getStreamingMsgId(); } TransferChannelType channelType = workerConfig.transferConfig.channelType(); boolean isMock = false; @@ -64,6 +82,14 @@ public class DataWriter { workerConfig.workerInternalConfig.workerName()); } + private static native long createWriterNative( + ChannelCreationParametersBuilder initialParameters, + byte[][] outputQueueIds, + long[] msgIds, + long channelSize, + byte[] confBytes, + boolean isMock); + /** * Write msg into the specified channel * @@ -82,9 +108,8 @@ public class DataWriter { * Write msg into the specified channels * * @param ids channel ids - * @param item message item data section is specified by [position, limit). item doesn't have - * to - * be a direct buffer. + * @param item message item data section is specified by [position, limit). + * item doesn't have to be a direct buffer. */ public void write(Set ids, ByteBuffer item) { int size = item.remaining(); @@ -104,6 +129,27 @@ public class DataWriter { } } + public Map getOutputCheckpoints() { + long[] msgId = getOutputMsgIdNative(nativeWriterPtr); + Map res = new HashMap<>(outputChannels.size()); + for (int i = 0; i < outputChannels.size(); ++i) { + res.put(outputChannels.get(i), new OffsetInfo(msgId[i])); + } + LOG.info("got output points, {}.", res); + return res; + } + + public void broadcastBarrier(long checkpointId, ByteBuffer attach) { + LOG.info("Broadcast barrier, cpId={}.", checkpointId); + Preconditions.checkArgument(attach.order() == ByteOrder.nativeOrder()); + broadcastBarrierNative(nativeWriterPtr, checkpointId, attach.array()); + } + + public void clearCheckpoint(long checkpointId) { + LOG.info("Producer clear checkpoint, checkpointId={}.", checkpointId); + clearCheckpointNative(nativeWriterPtr, checkpointId); + } + /** * stop writer */ @@ -124,14 +170,6 @@ public class DataWriter { LOG.info("Finish closing data writer."); } - private static native long createWriterNative( - ChannelCreationParametersBuilder initialParameters, - 
byte[][] outputQueueIds, - long[] msgIds, - long channelSize, - byte[] confBytes, - boolean isMock); - private native long writeMessageNative( long nativeQueueProducerPtr, long nativeIdPtr, long address, int size); @@ -139,4 +177,15 @@ public class DataWriter { private native void closeWriterNative(long nativeQueueProducerPtr); + private native long[] getOutputMsgIdNative(long nativeQueueProducerPtr); + + private native void broadcastBarrierNative( + long nativeQueueProducerPtr, long checkpointId, + byte[] data); + + private native void clearCheckpointNative( + long nativeQueueProducerPtr, + long checkpointId + ); + } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/Message.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/Message.java deleted file mode 100644 index f48cb6f77..000000000 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/Message.java +++ /dev/null @@ -1,22 +0,0 @@ -package io.ray.streaming.runtime.transfer; - -import java.nio.ByteBuffer; - -public interface Message { - - /** - * Message data - *

- * Message body is a direct byte buffer, which may be invalid after call next - * DataReader#getBundleNative. Please consume this buffer fully - * before next call getBundleNative. - * - * @return message body - */ - ByteBuffer body(); - - /** - * @return timestamp when item is written by upstream DataWriter - */ - long timestamp(); -} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelId.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java similarity index 97% rename from streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelId.java rename to streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java index 75904e19e..07e98ae3f 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelId.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelId.java @@ -1,4 +1,4 @@ -package io.ray.streaming.runtime.transfer; +package io.ray.streaming.runtime.transfer.channel; import com.google.common.base.FinalizablePhantomReference; import com.google.common.base.FinalizableReferenceQueue; @@ -41,47 +41,6 @@ public class ChannelId { this.nativeIdPtr = nativeIdPtr; } - public byte[] getBytes() { - return bytes; - } - - public ByteBuffer getBuffer() { - return buffer; - } - - public long getAddress() { - return address; - } - - public long getNativeIdPtr() { - if (nativeIdPtr == 0) { - throw new IllegalStateException("native ID not available"); - } - return nativeIdPtr; - } - - @Override - public String toString() { - return strId; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - ChannelId that = (ChannelId) o; - return strId.equals(that.strId); - } - - @Override - 
public int hashCode() { - return strId.hashCode(); - } - private static native long createNativeId(long idAddress); private static native void destroyNativeId(long nativeIdPtr); @@ -164,7 +123,7 @@ public class ChannelId { * @param id hex string representation of channel id * @return bytes representation of channel id */ - static byte[] idStrToBytes(String id) { + public static byte[] idStrToBytes(String id) { byte[] idBytes = BaseEncoding.base16().decode(id.toUpperCase()); assert idBytes.length == ChannelId.ID_LENGTH; return idBytes; @@ -174,10 +133,51 @@ public class ChannelId { * @param id bytes representation of channel id * @return hex string representation of channel id */ - static String idBytesToStr(byte[] id) { + public static String idBytesToStr(byte[] id) { assert id.length == ChannelId.ID_LENGTH; return BaseEncoding.base16().encode(id).toLowerCase(); } + public byte[] getBytes() { + return bytes; + } + + public ByteBuffer getBuffer() { + return buffer; + } + + public long getAddress() { + return address; + } + + public long getNativeIdPtr() { + if (nativeIdPtr == 0) { + throw new IllegalStateException("native ID not available"); + } + return nativeIdPtr; + } + + @Override + public String toString() { + return strId; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ChannelId that = (ChannelId) o; + return strId.equals(that.strId); + } + + @Override + public int hashCode() { + return strId.hashCode(); + } + } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java new file mode 100644 index 000000000..584f411ee --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelRecoverInfo.java @@ -0,0 +1,60 @@ 
+package io.ray.streaming.runtime.transfer.channel; + +import com.google.common.base.MoreObjects; +import java.io.Serializable; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class ChannelRecoverInfo implements Serializable { + + private static final Logger LOG = LoggerFactory.getLogger(ChannelRecoverInfo.class); + public Map queueCreationStatusMap; + + + public ChannelRecoverInfo(Map queueCreationStatusMap) { + this.queueCreationStatusMap = queueCreationStatusMap; + } + + public Set getDataLostQueues() { + Set dataLostQueues = new HashSet<>(); + queueCreationStatusMap.forEach((q, status) -> { + if (status.equals(ChannelCreationStatus.DataLost)) { + dataLostQueues.add(q); + } + }); + return dataLostQueues; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("dataLostQueues", getDataLostQueues()) + .toString(); + } + + public enum ChannelCreationStatus { + FreshStarted(0), + PullOk(1), + Timeout(2), + DataLost(3); + + private int id; + + ChannelCreationStatus(int id) { + this.id = id; + } + + public static ChannelCreationStatus fromInt(int id) { + for (ChannelCreationStatus status : ChannelCreationStatus.values()) { + if (status.id == id) { + return status; + } + } + return null; + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelUtils.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java similarity index 94% rename from streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelUtils.java rename to streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java index c62b21018..74e813134 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/ChannelUtils.java +++ 
b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/ChannelUtils.java @@ -1,4 +1,4 @@ -package io.ray.streaming.runtime.transfer; +package io.ray.streaming.runtime.transfer.channel; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.generated.Streaming; @@ -10,7 +10,7 @@ public class ChannelUtils { private static final Logger LOGGER = LoggerFactory.getLogger(ChannelUtils.class); - static byte[] toNativeConf(StreamingWorkerConfig workerConfig) { + public static byte[] toNativeConf(StreamingWorkerConfig workerConfig) { Streaming.StreamingConfig.Builder builder = Streaming.StreamingConfig.newBuilder(); // job name diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java new file mode 100644 index 000000000..5c3ea02a7 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/channel/OffsetInfo.java @@ -0,0 +1,31 @@ +package io.ray.streaming.runtime.transfer.channel; + +import com.google.common.base.MoreObjects; +import java.io.Serializable; + +/** + * This data structure contains offset used by streaming queue. 
+ */ +public class OffsetInfo implements Serializable { + + private long streamingMsgId; + + public OffsetInfo(long streamingMsgId) { + this.streamingMsgId = streamingMsgId; + } + + public long getStreamingMsgId() { + return streamingMsgId; + } + + public void setStreamingMsgId(long streamingMsgId) { + this.streamingMsgId = streamingMsgId; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("streamingMsgId", streamingMsgId) + .toString(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java new file mode 100644 index 000000000..f4d909ce7 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/exception/ChannelInterruptException.java @@ -0,0 +1,22 @@ +package io.ray.streaming.runtime.transfer.exception; + +import io.ray.streaming.runtime.transfer.DataReader; +import io.ray.streaming.runtime.transfer.DataWriter; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import java.nio.ByteBuffer; + +/** + * when {@link DataReader#stop()} or {@link DataWriter#stop()} is called, this exception might be + * thrown in {@link DataReader#read(long)} and {@link DataWriter#write(ChannelId, ByteBuffer)}, + * which means the read/write operation is failed. 
+ */ +public class ChannelInterruptException extends RuntimeException { + + public ChannelInterruptException() { + super(); + } + + public ChannelInterruptException(String message) { + super(message); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java new file mode 100644 index 000000000..ffc694c53 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/BarrierMessage.java @@ -0,0 +1,34 @@ +package io.ray.streaming.runtime.transfer.message; + +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import java.nio.ByteBuffer; +import java.util.Map; + + +public class BarrierMessage extends ChannelMessage { + + private final ByteBuffer data; + private final long checkpointId; + private final Map inputOffsets; + + public BarrierMessage( + long msgId, long timestamp, String channelId, + ByteBuffer data, long checkpointId, Map inputOffsets) { + super(msgId, timestamp, channelId); + this.data = data; + this.checkpointId = checkpointId; + this.inputOffsets = inputOffsets; + } + + public ByteBuffer getData() { + return data; + } + + public long getCheckpointId() { + return checkpointId; + } + + public Map getInputOffsets() { + return inputOffsets; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java new file mode 100644 index 000000000..6bfa4dca5 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/ChannelMessage.java @@ -0,0 +1,26 @@ +package io.ray.streaming.runtime.transfer.message; + +public class ChannelMessage { + + private final long msgId; + private final long timestamp; + private 
final String channelId; + + public ChannelMessage(long msgId, long timestamp, String channelId) { + this.msgId = msgId; + this.timestamp = timestamp; + this.channelId = channelId; + } + + public long getMsgId() { + return msgId; + } + + public long getTimestamp() { + return timestamp; + } + + public String getChannelId() { + return channelId; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java new file mode 100644 index 000000000..b3cf779bf --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/transfer/message/DataMessage.java @@ -0,0 +1,21 @@ +package io.ray.streaming.runtime.transfer.message; + +import java.nio.ByteBuffer; + + +/** + * DataMessage represents data between upstream and downstream operators. + */ +public class DataMessage extends ChannelMessage { + + private final ByteBuffer body; + + public DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) { + super(msgId, timestamp, channelId); + this.body = body; + } + + public ByteBuffer body() { + return body; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java new file mode 100644 index 000000000..c32d2ef4f --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/CheckpointStateUtil.java @@ -0,0 +1,59 @@ +package io.ray.streaming.runtime.util; + +import io.ray.streaming.runtime.context.ContextBackend; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Handle exception for checkpoint state + */ +public class CheckpointStateUtil { + + private static final Logger LOG = LoggerFactory.getLogger(CheckpointStateUtil.class); + + /** + * DO NOT ALLOW 
GET EXCEPTION WHEN LOADING CHECKPOINT + * + * @param checkpointState state backend + * @param cpKey checkpoint key + */ + public static byte[] get(ContextBackend checkpointState, String cpKey) { + byte[] val; + try { + val = checkpointState.get(cpKey); + } catch (Exception e) { + throw new CheckpointStateRuntimeException( + String.format("Failed to get %s from state backend.", cpKey), e); + } + return val; + } + + /** + * ALLOW PUT EXCEPTION WHEN SAVING CHECKPOINT + * + * @param checkpointState state backend + * @param key checkpoint key + * @param val checkpoint value + */ + public static void put(ContextBackend checkpointState, String key, byte[] val) { + try { + checkpointState.put(key, val); + } catch (Exception e) { + LOG.error("Failed to put key {} to state backend.", key, e); + } + } + + public static class CheckpointStateRuntimeException extends RuntimeException { + + public CheckpointStateRuntimeException() { + } + + public CheckpointStateRuntimeException(String message) { + super(message); + } + + public CheckpointStateRuntimeException(String message, Throwable cause) { + super(message, cause); + } + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java index f5120fb3a..2238e82aa 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/EnvUtil.java @@ -3,13 +3,29 @@ package io.ray.streaming.runtime.util; import io.ray.runtime.RayNativeRuntime; import io.ray.runtime.util.JniUtils; import java.lang.management.ManagementFactory; +import java.net.InetAddress; +import java.net.UnknownHostException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class EnvUtil { + private static final Logger LOG = LoggerFactory.getLogger(EnvUtil.class); + public static String getJvmPid() { 
return ManagementFactory.getRuntimeMXBean().getName().split("@")[0]; } + public static String getHostName() { + String hostname = ""; + try { + hostname = InetAddress.getLocalHost().getHostName(); + } catch (UnknownHostException e) { + LOG.error("Error occurs while fetching local host.", e); + } + return hostname; + } + public static void loadNativeLibraries() { // Explicitly load `RayNativeRuntime`, to make sure `core_worker_library_java` // is loaded before `streaming_java`. diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java new file mode 100644 index 000000000..66777b702 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/ResourceUtil.java @@ -0,0 +1,220 @@ +package io.ray.streaming.runtime.util; + +import com.sun.management.OperatingSystemMXBean; +import io.ray.api.id.UniqueId; +import io.ray.streaming.runtime.core.resource.Container; +import io.ray.streaming.runtime.core.resource.ContainerId; +import java.io.BufferedInputStream; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Resource Utility collects current OS and JVM resource usage information + */ +public class ResourceUtil { + + public static final Logger LOG = LoggerFactory.getLogger(ResourceUtil.class); + + /** + * Refer to: https://docs.oracle.com/javase/8/docs/jre/api/management/extension/com/sun/management/OperatingSystemMXBean.html + */ + private static OperatingSystemMXBean osmxb = + (OperatingSystemMXBean) ManagementFactory.getOperatingSystemMXBean(); + + /** + * Log current jvm process's 
memory detail + */ + public static void logProcessMemoryDetail() { + int mb = 1024 * 1024; + + //Getting the runtime reference from system + Runtime runtime = Runtime.getRuntime(); + + StringBuilder sb = new StringBuilder(32); + + sb.append("used memory: ").append((runtime.totalMemory() - runtime.freeMemory()) / mb) + .append(", free memory: ").append(runtime.freeMemory() / mb) + .append(", total memory: ").append(runtime.totalMemory() / mb) + .append(", max memory: ").append(runtime.maxMemory() / mb); + + if (LOG.isInfoEnabled()) { + LOG.info(sb.toString()); + } + } + + /** + * @return jvm heap usage ratio. note that one of the survivor space is not include in total + * memory while calculating this ratio. + */ + public static double getJvmHeapUsageRatio() { + Runtime runtime = Runtime.getRuntime(); + return (runtime.totalMemory() - runtime.freeMemory()) * 1.0 / runtime.maxMemory(); + } + + /** + * @return jvm heap usage(in bytes). + * note that this value doesn't include one of the survivor space. + */ + public static long getJvmHeapUsageInBytes() { + Runtime runtime = Runtime.getRuntime(); + return runtime.totalMemory() - runtime.freeMemory(); + } + + /** + * @return the total amount of physical memory in bytes. + */ + public static long getSystemTotalMemory() { + return osmxb.getTotalPhysicalMemorySize(); + } + + /** + * @return the used system physical memory in bytes + */ + public static long getSystemMemoryUsage() { + long totalMemory = osmxb.getTotalPhysicalMemorySize(); + long freeMemory = osmxb.getFreePhysicalMemorySize(); + return totalMemory - freeMemory; + } + + /** + * @return the ratio of used system physical memory. This value is a double in the [0.0,1.0] + */ + public static double getSystemMemoryUsageRatio() { + double totalMemory = osmxb.getTotalPhysicalMemorySize(); + double freeMemory = osmxb.getFreePhysicalMemorySize(); + double ratio = freeMemory / totalMemory; + return 1 - ratio; + } + + /** + * @return the cpu load for current jvm process. 
This value is a double in the [0.0,1.0] + */ + public static double getProcessCpuUsage() { + return osmxb.getProcessCpuLoad(); + } + + /** + * @return the system cpu usage. + * This value is a double in the [0.0,1.0] + * We will try to use `vsar` to get cpu usage by default, + * and use MXBean if any exception raised. + */ + public static double getSystemCpuUsage() { + double cpuUsage = 0.0; + try { + cpuUsage = getSystemCpuUtilByVsar(); + } catch (Exception e) { + cpuUsage = getSystemCpuUtilByMXBean(); + } + return cpuUsage; + } + + /** + * Returns the "recent cpu usage" for the whole system. This value is a double in the [0.0,1.0] + * interval. A value of 0.0 means that all CPUs were idle during the recent period of time + * observed, while a value of 1.0 means that all CPUs were actively running 100% of the time + * during the recent period being observed + */ + public static double getSystemCpuUtilByMXBean() { + return osmxb.getSystemCpuLoad(); + } + + /** + * Get system cpu util by vsar + */ + public static double getSystemCpuUtilByVsar() throws Exception { + double cpuUsageFromVsar = 0.0; + String[] vsarCpuCommand = {"/bin/sh", "-c", "vsar --check --cpu -s util"}; + try { + Process proc = Runtime.getRuntime().exec(vsarCpuCommand); + BufferedInputStream bis = new BufferedInputStream(proc.getInputStream()); + BufferedReader br = new BufferedReader(new InputStreamReader(bis)); + String line; + List processPidList = new ArrayList<>(); + while ((line = br.readLine()) != null) { + processPidList.add(line); + } + if (!processPidList.isEmpty()) { + String[] split = processPidList.get(0).split("="); + cpuUsageFromVsar = Double.parseDouble(split[1]) / 100.0D; + } else { + throw new IOException("Vsar check cpu usage failed, maybe vsar is not installed."); + } + } catch (Exception e) { + LOG.warn("Failed to get cpu usage by vsar.", e); + throw e; + } + return cpuUsageFromVsar; + } + + /** + * @returns the system load average for the last minute + */ + public static 
double getSystemLoadAverage() { + return osmxb.getSystemLoadAverage(); + } + + /** + * @return the number of available processors + */ + public static int getCpuCores() { + return osmxb.getAvailableProcessors(); + } + + /** + * Get containers by hostname or address + * + * @param containers container list + * @param containerHosts container hostname or address set + * @return matched containers + */ + public static List getContainersByHostname( + List containers, + Collection containerHosts) { + + return containers.stream() + .filter(container -> + containerHosts.contains(container.getHostname()) || + containerHosts.contains(container.getAddress())) + .collect(Collectors.toList()); + } + + /** + * Get the first container whose hostname or address equals the given name + * + * @param hostName container hostname or address + * @return the matched container, if any + */ + public static Optional getContainerByHostname( + List containers, + String hostName) { + return containers.stream() + .filter(container -> container.getHostname().equals(hostName) || + container.getAddress().equals(hostName)) + .findFirst(); + } + + /** + * Get the first container with the given id + * + * @param containerID container id + * @return the matched container, if any + */ + public static Optional getContainerById( + List containers, + ContainerId containerID) { + return containers.stream() + .filter(container -> container.getId().equals(containerID)) + .findFirst(); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java new file mode 100644 index 000000000..420215df1 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/util/Serializer.java @@ -0,0 +1,15 @@ +package io.ray.streaming.runtime.util; + +import io.ray.runtime.serializer.FstSerializer; + +public class Serializer { + + public static byte[] encode(Object obj) { + return FstSerializer.encode(obj); + } + + public static T decode(byte[] bytes) { + return
FstSerializer.decode(bytes); + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java index 26a71453b..7aac6b0c6 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/JobWorker.java @@ -1,20 +1,32 @@ package io.ray.streaming.runtime.worker; +import io.ray.api.Ray; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.config.types.TransferChannelType; +import io.ray.streaming.runtime.context.ContextBackend; +import io.ray.streaming.runtime.context.ContextBackendFactory; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; import io.ray.streaming.runtime.core.processor.OneInputProcessor; import io.ray.streaming.runtime.core.processor.ProcessBuilder; import io.ray.streaming.runtime.core.processor.SourceProcessor; import io.ray.streaming.runtime.core.processor.StreamProcessor; import io.ray.streaming.runtime.master.JobMaster; +import io.ray.streaming.runtime.master.coordinator.command.WorkerRollbackRequest; +import io.ray.streaming.runtime.message.CallResult; +import io.ray.streaming.runtime.rpc.RemoteCallMaster; import io.ray.streaming.runtime.transfer.TransferHandler; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo.ChannelCreationStatus; +import io.ray.streaming.runtime.util.CheckpointStateUtil; import io.ray.streaming.runtime.util.EnvUtil; +import io.ray.streaming.runtime.util.Serializer; import io.ray.streaming.runtime.worker.context.JobWorkerContext; import io.ray.streaming.runtime.worker.tasks.OneInputStreamTask; import io.ray.streaming.runtime.worker.tasks.SourceStreamTask; import 
io.ray.streaming.runtime.worker.tasks.StreamTask; import java.io.Serializable; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,90 +48,223 @@ public class JobWorker implements Serializable { EnvUtil.loadNativeLibraries(); } + public final Object initialStateChangeLock = new Object(); + /** + * isRecreate=true means this worker is initialized more than once after the actor was created. + */ + public AtomicBoolean isRecreate = new AtomicBoolean(false); + public ContextBackend contextBackend; private JobWorkerContext workerContext; private ExecutionVertex executionVertex; private StreamingWorkerConfig workerConfig; - + /** + * The while-loop thread to read messages, process messages, and write results. + */ private StreamTask task; + /** + * transferHandler handles messages by ray direct call. + */ private TransferHandler transferHandler; + /** + * A flag to avoid duplicated rollback. Becomes true after requesting + * rollback, set to false when the rollback finishes. + */ + private boolean isNeedRollback = false; + private int rollbackCount = 0; - public JobWorker() { - LOG.info("Creating job worker succeeded."); + public JobWorker(ExecutionVertex executionVertex) { + LOG.info("Creating job worker."); + + // TODO: the following 3 lines are duplicated with those in init(), try to optimise it later.
+ this.executionVertex = executionVertex; + this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig()); + this.contextBackend = ContextBackendFactory.getContextBackend(this.workerConfig); + + LOG.info("Ray.getRuntimeContext().wasCurrentActorRestarted()={}", + Ray.getRuntimeContext().wasCurrentActorRestarted()); + if (!Ray.getRuntimeContext().wasCurrentActorRestarted()) { + saveContext(); + LOG.info("Job worker is fresh started, init success."); + return; + } + + LOG.info("Begin load job worker checkpoint state."); + + byte[] bytes = CheckpointStateUtil.get(contextBackend, getJobWorkerContextKey()); + if (bytes != null) { + JobWorkerContext context = Serializer.decode(bytes); + LOG.info("Worker recover from checkpoint state, byte len={}, context={}.", bytes.length, + context); + init(context); + requestRollback("LoadCheckpoint request rollback in new actor."); + } else { + LOG.error( + "Worker is reconstructed, but can't load checkpoint. " + + "Check whether you checkpoint state is reliable. Current checkpoint state is {}.", + contextBackend.getClass().getName()); + } + } + + public synchronized void saveContext() { + byte[] contextBytes = Serializer.encode(workerContext); + String key = getJobWorkerContextKey(); + LOG.info("Saving context, worker context={}, serialized byte length={}, key={}.", workerContext, + contextBytes.length, key); + CheckpointStateUtil.put(contextBackend, key, contextBytes); } /** * Initialize JobWorker and data communication pipeline. */ public Boolean init(JobWorkerContext workerContext) { - LOG.info("Initiating job worker: {}. Worker context is: {}.", - workerContext.getWorkerName(), workerContext); + // IMPORTANT: some test cases depend on this log to find workers' pid, + // be careful when changing this log. + LOG.info("Initiating job worker: {}.
Worker context is: {}, pid={}.", + workerContext.getWorkerName(), workerContext, EnvUtil.getJvmPid()); + + this.workerContext = workerContext; + this.executionVertex = workerContext.getExecutionVertex(); + this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig()); + // init state backend + this.contextBackend = ContextBackendFactory.getContextBackend(this.workerConfig); + + LOG.info("Initiating job worker succeeded: {}.", workerContext.getWorkerName()); + saveContext(); + return true; + } + + /** + * Start worker's stream tasks with specific checkpoint ID. + * + * @return a {@link CallResult} with {@link ChannelRecoverInfo}, + * contains {@link ChannelCreationStatus} of each input queue. + */ + public CallResult rollback(Long checkpointId, Long startRollbackTs) { + synchronized (initialStateChangeLock) { + if (task != null && task.isAlive() && checkpointId == task.lastCheckpointId && + task.isInitialState) { + return CallResult.skipped("Task is already in initial state, skip this rollback."); + } + } + long remoteCallCost = System.currentTimeMillis() - startRollbackTs; + + LOG.info("Start rollback[{}], checkpoint is {}, remote call cost {}ms.", + executionVertex.getExecutionJobVertexName(), checkpointId, remoteCallCost); + + rollbackCount++; + if (rollbackCount > 1) { + isRecreate.set(true); + } try { - this.workerContext = workerContext; - this.executionVertex = workerContext.getExecutionVertex(); - this.workerConfig = new StreamingWorkerConfig(executionVertex.getWorkerConfig()); - //Init transfer TransferChannelType channelType = workerConfig.transferConfig.channelType(); if (TransferChannelType.NATIVE_CHANNEL == channelType) { transferHandler = new TransferHandler(); } - // create stream task - task = createStreamTask(); - if (task == null) { - return false; + if (task != null) { + // make sure the task is closed + task.close(); + task = null; } - } catch (Exception e) { - LOG.error("Failed to initiate job worker.", e); - return false; - } 
- LOG.info("Initiating job worker succeeded: {}.", workerContext.getWorkerName()); - return true; - } - /** - * Start worker's stream tasks. - * - * @return result - */ - public Boolean start() { - try { - task.start(); + // create stream task + task = createStreamTask(checkpointId); + ChannelRecoverInfo channelRecoverInfo = task.recover(isRecreate.get()); + isNeedRollback = false; + + LOG.info("Rollback job worker success, checkpoint is {}, channelRecoverInfo is {}.", + checkpointId, channelRecoverInfo); + + return CallResult.success(channelRecoverInfo); } catch (Exception e) { - LOG.error("Start worker [{}] occur error.", executionVertex.getExecutionVertexName(), e); - return false; + LOG.error("Rollback job worker has exception.", e); + return CallResult.fail(ExceptionUtils.getStackTrace(e)); } - return true; } /** * Create tasks based on the processor corresponding of the operator. */ - private StreamTask createStreamTask() { - StreamTask task = null; + private StreamTask createStreamTask(long checkpointId) { + StreamTask task; StreamProcessor streamProcessor = ProcessBuilder .buildProcessor(executionVertex.getStreamOperator()); LOG.debug("Stream processor created: {}.", streamProcessor); - try { - if (streamProcessor instanceof SourceProcessor) { - task = new SourceStreamTask(getTaskId(), streamProcessor, this); - } else if (streamProcessor instanceof OneInputProcessor) { - task = new OneInputStreamTask(getTaskId(), streamProcessor, this); - } else { - throw new RuntimeException("Unsupported processor type:" + streamProcessor); - } - } catch (Exception e) { - LOG.info("Failed to create stream task.", e); - return task; + if (streamProcessor instanceof SourceProcessor) { + task = new SourceStreamTask(streamProcessor, this, checkpointId); + } else if (streamProcessor instanceof OneInputProcessor) { + task = new OneInputStreamTask(streamProcessor, this, checkpointId); + } else { + throw new RuntimeException("Unsupported processor type:" + streamProcessor); } 
LOG.info("Stream task created: {}.", task); return task; } - public int getTaskId() { - return executionVertex.getExecutionVertexId(); + // ---------------------------------------------------------------------- + // Checkpoint + // ---------------------------------------------------------------------- + + /** + * Trigger source job worker checkpoint + */ + public Boolean triggerCheckpoint(Long barrierId) { + LOG.info("Receive trigger, barrierId is {}.", barrierId); + if (task != null) { + return task.triggerCheckpoint(barrierId); + } + return false; + } + + public Boolean notifyCheckpointTimeout(Long checkpointId) { + LOG.info("Notify checkpoint timeout, checkpoint id is {}.", checkpointId); + if (task != null) { + task.notifyCheckpointTimeout(checkpointId); + } + return true; + } + + public Boolean clearExpiredCheckpoint(Long expiredStateCpId, Long expiredQueueCpId) { + LOG.info("Clear expired checkpoint state, checkpoint id is {}; " + + "Clear expired queue msg, checkpoint id is {}", + expiredStateCpId, expiredQueueCpId); + if (task != null) { + if (expiredStateCpId > 0) { + task.clearExpiredCpState(expiredStateCpId); + } + task.clearExpiredQueueMsg(expiredQueueCpId); + } + return true; + } + + // ---------------------------------------------------------------------- + // Failover + // ---------------------------------------------------------------------- + public void requestRollback(String exceptionMsg) { + LOG.info("Request rollback."); + isNeedRollback = true; + isRecreate.set(true); + boolean requestRet = RemoteCallMaster.requestJobWorkerRollback( + workerContext.getMaster(), new WorkerRollbackRequest( + workerContext.getWorkerActorId(), + exceptionMsg, + EnvUtil.getHostName(), + EnvUtil.getJvmPid() + )); + if (!requestRet) { + LOG.warn("Job worker request rollback failed! exceptionMsg={}.", exceptionMsg); + } + } + + public Boolean checkIfNeedRollback(Long startCallTs) { + // No save checkpoint in this query. 
+ long remoteCallCost = System.currentTimeMillis() - startCallTs; + LOG.info("Finished checking if need to rollback with result: {}, rpc delay={}ms.", + isNeedRollback, remoteCallCost); + return isNeedRollback; } public StreamingWorkerConfig getWorkerConfig() { @@ -138,11 +283,19 @@ public class JobWorker implements Serializable { return task; } + private String getJobWorkerContextKey() { + return workerConfig.checkpointConfig.jobWorkerContextCpPrefixKey() + + workerConfig.commonConfig.jobName() + + "_" + executionVertex.getExecutionVertexId(); + } + /** * Used by upstream streaming queue to send data to this actor */ public void onReaderMessage(byte[] buffer) { - transferHandler.onReaderMessage(buffer); + if (transferHandler != null) { + transferHandler.onReaderMessage(buffer); + } } /** @@ -159,7 +312,9 @@ public class JobWorker implements Serializable { * Used by downstream streaming queue to send data to this actor */ public void onWriterMessage(byte[] buffer) { - transferHandler.onWriterMessage(buffer); + if (transferHandler != null) { + transferHandler.onWriterMessage(buffer); + } } /** @@ -172,4 +327,5 @@ public class JobWorker implements Serializable { } return transferHandler.onWriterMessageSync(buffer); } + } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java index 495a2b187..e4fd3b992 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/context/JobWorkerContext.java @@ -3,7 +3,9 @@ package io.ray.streaming.runtime.worker.context; import com.google.common.base.MoreObjects; import com.google.protobuf.ByteString; import io.ray.api.ActorHandle; +import io.ray.api.id.ActorId; import io.ray.runtime.actor.NativeActorHandle; +import 
io.ray.streaming.runtime.config.global.CommonConfig; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; import io.ray.streaming.runtime.generated.RemoteCall; import io.ray.streaming.runtime.master.JobMaster; @@ -33,6 +35,10 @@ public class JobWorkerContext implements Serializable { this.executionVertex = executionVertex; } + public ActorId getWorkerActorId() { + return executionVertex.getWorkerActorId(); + } + public int getWorkerId() { return executionVertex.getExecutionVertexId(); } @@ -53,6 +59,14 @@ public class JobWorkerContext implements Serializable { return executionVertex; } + public Map getConf() { + return getExecutionVertex().getWorkerConfig(); + } + + public String getJobName() { + return getConf().get(CommonConfig.JOB_NAME); + } + @Override public String toString() { return MoreObjects.toStringHelper(this) diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java index 9ce0c5fb7..eeddf13e5 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/InputStreamTask.java @@ -2,22 +2,31 @@ package io.ray.streaming.runtime.worker.tasks; import com.google.common.base.MoreObjects; import io.ray.streaming.runtime.core.processor.Processor; +import io.ray.streaming.runtime.generated.RemoteCall; import io.ray.streaming.runtime.serialization.CrossLangSerializer; import io.ray.streaming.runtime.serialization.JavaSerializer; import io.ray.streaming.runtime.serialization.Serializer; -import io.ray.streaming.runtime.transfer.Message; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import io.ray.streaming.runtime.transfer.exception.ChannelInterruptException; +import 
io.ray.streaming.runtime.transfer.message.BarrierMessage; +import io.ray.streaming.runtime.transfer.message.ChannelMessage; +import io.ray.streaming.runtime.transfer.message.DataMessage; import io.ray.streaming.runtime.worker.JobWorker; +import java.util.Map; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public abstract class InputStreamTask extends StreamTask { - private volatile boolean running = true; - private volatile boolean stopped = false; - private long readTimeoutMillis; + private static final Logger LOG = LoggerFactory.getLogger(InputStreamTask.class); + private final io.ray.streaming.runtime.serialization.Serializer javaSerializer; private final io.ray.streaming.runtime.serialization.Serializer crossLangSerializer; + private final long readTimeoutMillis; - public InputStreamTask(int taskId, Processor processor, JobWorker jobWorker) { - super(taskId, processor, jobWorker); + public InputStreamTask(Processor processor, JobWorker jobWorker, long lastCheckpointId) { + super(processor, jobWorker, lastCheckpointId); readTimeoutMillis = jobWorker.getWorkerConfig().transferConfig.readerTimerIntervalMs(); javaSerializer = new JavaSerializer(); crossLangSerializer = new CrossLangSerializer(); @@ -29,35 +38,64 @@ public abstract class InputStreamTask extends StreamTask { @Override public void run() { - while (running) { - Message item = reader.read(readTimeoutMillis); - if (item != null) { - byte[] bytes = new byte[item.body().remaining() - 1]; - byte typeId = item.body().get(); - item.body().get(bytes); - Object obj; - if (typeId == Serializer.JAVA_TYPE_ID) { - obj = javaSerializer.deserialize(bytes); - } else { - obj = crossLangSerializer.deserialize(bytes); + try { + while (running) { + ChannelMessage item; + + // reader.read() will change the consumer state once it gets an item. This lock is to + // ensure the worker gets the correct isInitialState value in exactly-once-mode's rollback.
+ synchronized (jobWorker.initialStateChangeLock) { + item = reader.read(readTimeoutMillis); + if (item != null) { + isInitialState = false; + } else { + continue; + } } - processor.process(obj); + + if (item instanceof DataMessage) { + DataMessage dataMessage = (DataMessage) item; + byte[] bytes = new byte[dataMessage.body().remaining() - 1]; + byte typeId = dataMessage.body().get(); + dataMessage.body().get(bytes); + Object obj; + if (typeId == Serializer.JAVA_TYPE_ID) { + obj = javaSerializer.deserialize(bytes); + } else { + obj = crossLangSerializer.deserialize(bytes); + } + processor.process(obj); + } else if (item instanceof BarrierMessage) { + final BarrierMessage queueBarrier = (BarrierMessage) item; + byte[] barrierData = new byte[queueBarrier.getData().remaining()]; + queueBarrier.getData().get(barrierData); + RemoteCall.Barrier barrierPb = RemoteCall.Barrier.parseFrom(barrierData); + final long checkpointId = barrierPb.getId(); + LOG.info("Start to do checkpoint {}, worker name is {}.", checkpointId, + jobWorker.getWorkerContext().getWorkerName()); + + final Map inputPoints = queueBarrier.getInputOffsets(); + doCheckpoint(checkpointId, inputPoints); + LOG.info("Do checkpoint {} success.", checkpointId); + } + } + } catch (Throwable throwable) { + if (throwable instanceof ChannelInterruptException || + ExceptionUtils.getRootCause(throwable) instanceof ChannelInterruptException) { + LOG.info("queue has stopped."); + } else { + // error occurred, need to rollback + LOG.error("Last success checkpointId={}, now occur error.", lastCheckpointId, throwable); + requestRollback(ExceptionUtils.getStackTrace(throwable)); } } + LOG.info("Input stream task thread exit."); stopped = true; } - @Override - protected void cancelTask() throws Exception { - running = false; - while (!stopped) { - } - } - @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("taskId", taskId) .add("processor", processor) .toString(); } diff --git 
a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java index 16293f9ae..8eaf2ef66 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java @@ -8,7 +8,7 @@ import io.ray.streaming.runtime.worker.JobWorker; */ public class OneInputStreamTask extends InputStreamTask { - public OneInputStreamTask(int taskId, Processor inputProcessor, JobWorker jobWorker) { - super(taskId, inputProcessor, jobWorker); + public OneInputStreamTask(Processor inputProcessor, JobWorker jobWorker, long lastCheckpointId) { + super(inputProcessor, jobWorker, lastCheckpointId); } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java index 3c70ece44..9fc94c06d 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/SourceStreamTask.java @@ -3,7 +3,10 @@ package io.ray.streaming.runtime.worker.tasks; import io.ray.streaming.operator.SourceOperator; import io.ray.streaming.runtime.core.processor.Processor; import io.ray.streaming.runtime.core.processor.SourceProcessor; +import io.ray.streaming.runtime.transfer.exception.ChannelInterruptException; import io.ray.streaming.runtime.worker.JobWorker; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -13,12 +16,19 @@ public class SourceStreamTask extends StreamTask { 
private final SourceProcessor sourceProcessor; + + /** + * The pending barrier ID to be triggered. + */ + private final AtomicReference pendingBarrier = new AtomicReference<>(); + private long lastCheckpointId = 0; /* NOTE(review): hides StreamTask.lastCheckpointId; confirm */ + /** * SourceStreamTask for executing a {@link SourceOperator}. It is responsible for running the * corresponding source operator. */ - public SourceStreamTask(int taskId, Processor sourceProcessor, JobWorker jobWorker) { - super(taskId, sourceProcessor, jobWorker); + public SourceStreamTask(Processor sourceProcessor, JobWorker jobWorker, long lastCheckpointId) { + super(sourceProcessor, jobWorker, lastCheckpointId); this.sourceProcessor = (SourceProcessor) processor; } @@ -29,12 +39,48 @@ public class SourceStreamTask extends StreamTask { @Override public void run() { LOG.info("Source stream task thread start."); + Long barrierId; + try { + while (running) { + isInitialState = false; - sourceProcessor.run(); + // check checkpoint + barrierId = pendingBarrier.get(); + if (barrierId != null) { + // Important: a checkpoint may time out, so master will use the old checkpoint id again + if (pendingBarrier.compareAndSet(barrierId, null)) { + // source fetcher only has outputPoints + LOG.info("Start to do checkpoint {}, worker name is {}.", + barrierId, jobWorker.getWorkerContext().getWorkerName()); + + doCheckpoint(barrierId, null); + + LOG.info("Finish to do checkpoint {}.", barrierId); + } else { + // pendingBarrier has been modified, which should not happen + LOG.warn("Pending checkpointId modify unexpected, expect={}, now={}.", barrierId, + pendingBarrier.get()); + } + } + + sourceProcessor.fetch(); + } + } catch (Throwable e) { + if (e instanceof ChannelInterruptException || + ExceptionUtils.getRootCause(e) instanceof ChannelInterruptException) { + LOG.info("queue has stopped."); + } else { + // error occurred, need to rollback + LOG.error("Last success checkpointId={}, now occur error.", lastCheckpointId, e); + requestRollback(ExceptionUtils.getStackTrace(e)); + } +
} + + LOG.info("Source stream task thread exit."); } @Override - protected void cancelTask() { + public boolean triggerCheckpoint(Long barrierId) { + return pendingBarrier.compareAndSet(null, barrierId); } - } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java index 79ad0100d..78ef0dbd4 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/StreamTask.java @@ -6,53 +6,103 @@ import io.ray.streaming.api.collector.Collector; import io.ray.streaming.api.context.RuntimeContext; import io.ray.streaming.api.partition.Partition; import io.ray.streaming.runtime.config.worker.WorkerInternalConfig; +import io.ray.streaming.runtime.context.ContextBackend; +import io.ray.streaming.runtime.context.OperatorCheckpointInfo; import io.ray.streaming.runtime.core.collector.OutputCollector; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionEdge; +import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; import io.ray.streaming.runtime.core.processor.Processor; -import io.ray.streaming.runtime.transfer.ChannelId; +import io.ray.streaming.runtime.generated.RemoteCall; +import io.ray.streaming.runtime.master.coordinator.command.WorkerCommitReport; +import io.ray.streaming.runtime.rpc.RemoteCallMaster; import io.ray.streaming.runtime.transfer.DataReader; import io.ray.streaming.runtime.transfer.DataWriter; +import io.ray.streaming.runtime.transfer.channel.ChannelRecoverInfo; +import io.ray.streaming.runtime.transfer.channel.OffsetInfo; +import io.ray.streaming.runtime.util.CheckpointStateUtil; +import io.ray.streaming.runtime.util.Serializer; import 
io.ray.streaming.runtime.worker.JobWorker; +import io.ray.streaming.runtime.worker.context.JobWorkerContext; import io.ray.streaming.runtime.worker.context.StreamingRuntimeContext; +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * {@link StreamTask} is a while-loop thread to read messages, process them, and send result + * messages to downstream operators. + */ public abstract class StreamTask implements Runnable { private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class); - - protected int taskId; + private final ContextBackend checkpointState; + public volatile boolean isInitialState = true; + public long lastCheckpointId; protected Processor processor; protected JobWorker jobWorker; protected DataReader reader; - List collectors = new ArrayList<>(); - + protected DataWriter writer; protected volatile boolean running = true; protected volatile boolean stopped = false; - + List collectors = new ArrayList<>(); + private Set outdatedCheckpoints = new HashSet<>(); private Thread thread; - protected StreamTask(int taskId, Processor processor, JobWorker jobWorker) { - this.taskId = taskId; + protected StreamTask(Processor processor, JobWorker jobWorker, long lastCheckpointId) { this.processor = processor; this.jobWorker = jobWorker; - prepareTask(); + this.checkpointState = jobWorker.contextBackend; + this.lastCheckpointId = lastCheckpointId; this.thread = new Thread(Ray.wrapRunnable(this), this.getClass().getName() + "-" + System.currentTimeMillis()); this.thread.setDaemon(true); } + public ChannelRecoverInfo recover(boolean isRecover) { + + if (isRecover) { + LOG.info("Stream task begin recover."); + } else { + LOG.info("Stream task first start begin."); + } + prepareTask(isRecover); + + // start
runner + ChannelRecoverInfo recoverInfo = new ChannelRecoverInfo(new HashMap<>()); + if (reader != null) { + recoverInfo = reader.getQueueRecoverInfo(); + } + + thread.setUncaughtExceptionHandler( + (t, e) -> LOG.error("Uncaught exception in runner thread.", e)); + LOG.info("Start stream task: {}.", this.getClass().getSimpleName()); + thread.start(); + + if (isRecover) { + LOG.info("Stream task recover end."); + } else { + LOG.info("Stream task first start finished."); + } + + return recoverInfo; + } + /** - * Build upstream and downstream data transmission channels according to {@link ExecutionVertex}. + * Load checkpoint and build upstream and downstream data transmission + * channels according to {@link ExecutionVertex}. */ - private void prepareTask() { - LOG.debug("Preparing stream task."); + private void prepareTask(boolean isRecreate) { + LOG.info("Preparing stream task, isRecreate={}.", isRecreate); ExecutionVertex executionVertex = jobWorker.getExecutionVertex(); // set vertex info into config for native using @@ -61,73 +111,92 @@ public abstract class StreamTask implements Runnable { jobWorker.getWorkerConfig().workerInternalConfig.setProperty( WorkerInternalConfig.OP_NAME_INTERNAL, executionVertex.getExecutionJobVertexName()); - // producer + OperatorCheckpointInfo operatorCheckpointInfo = new OperatorCheckpointInfo(); + byte[] bytes = null; + // Fetch checkpoint from storage only in recreate mode, not for a newly started worker + // (e.g. one created during rescaling). + if (isRecreate) { + String cpKey = genOpCheckpointKey(lastCheckpointId); + LOG.info("Getting task checkpoints from state, cpKey={}, checkpointId={}.", cpKey, + lastCheckpointId); + bytes = CheckpointStateUtil.get(checkpointState, cpKey); + if (bytes == null) { + String msg = String.format("Task recover failed, checkpoint is null!
cpKey=%s", cpKey); + throw new RuntimeException(msg); + } + } + + // when use memory state, if actor throw exception, will miss state + if (bytes != null) { + operatorCheckpointInfo = Serializer.decode(bytes); + processor.loadCheckpoint(operatorCheckpointInfo.processorCheckpoint); + LOG.info( + "Stream task recover from checkpoint state, checkpoint bytes len={}, checkpointInfo={}.", + bytes.length, operatorCheckpointInfo); + } + + // writer + if (!executionVertex.getOutputEdges().isEmpty()) { + LOG.info("Register queue writer, channels={}, outputCheckpoints={}.", + executionVertex.getOutputChannelIdList(), operatorCheckpointInfo.outputPoints); + writer = new DataWriter( + executionVertex.getOutputChannelIdList(), + executionVertex.getOutputActorList(), + operatorCheckpointInfo.outputPoints, + jobWorker.getWorkerConfig() + ); + } + + // reader + if (!executionVertex.getInputEdges().isEmpty()) { + LOG.info("Register queue reader, channels={}, inputCheckpoints={}.", + executionVertex.getInputChannelIdList(), operatorCheckpointInfo.inputPoints); + reader = new DataReader( + executionVertex.getInputChannelIdList(), + executionVertex.getInputActorList(), + operatorCheckpointInfo.inputPoints, + jobWorker.getWorkerConfig() + ); + } + + openProcessor(); + + LOG.debug("Finished preparing stream task."); + } + + /** + * Create one collector for each distinct output operator(i.e. 
each {@link ExecutionJobVertex}) + */ + private void openProcessor() { + ExecutionVertex executionVertex = jobWorker.getExecutionVertex(); List outputEdges = executionVertex.getOutputEdges(); - // merge all output edges to create writer - List outputChannelIds = new ArrayList<>(); - List targetActors = new ArrayList<>(); - - for (ExecutionEdge edge : outputEdges) { - String channelId = ChannelId.genIdStr( - taskId, - edge.getTargetExecutionVertex().getExecutionVertexId(), - executionVertex.getBuildTime()); - outputChannelIds.add(channelId); - targetActors.add(edge.getTargetExecutionVertex().getWorkerActor()); - } - - if (!targetActors.isEmpty()) { - DataWriter writer = new DataWriter( - outputChannelIds, targetActors, jobWorker.getWorkerConfig() - ); - - // create a collector for each output operator - Map> opGroupedChannelId = new HashMap<>(); - Map> opGroupedActor = new HashMap<>(); - Map opPartitionMap = new HashMap<>(); - for (int i = 0; i < outputEdges.size(); ++i) { - ExecutionEdge edge = outputEdges.get(i); - String opName = edge.getTargetExecutionJobVertexName(); - if (!opPartitionMap.containsKey(opName)) { - opGroupedChannelId.put(opName, new ArrayList<>()); - opGroupedActor.put(opName, new ArrayList<>()); - } - opGroupedChannelId.get(opName).add(outputChannelIds.get(i)); - opGroupedActor.get(opName).add(targetActors.get(i)); - opPartitionMap.put(opName, edge.getPartition()); + Map> opGroupedChannelId = new HashMap<>(); + Map> opGroupedActor = new HashMap<>(); + Map opPartitionMap = new HashMap<>(); + for (int i = 0; i < outputEdges.size(); ++i) { + ExecutionEdge edge = outputEdges.get(i); + String opName = edge.getTargetExecutionJobVertexName(); + if (!opPartitionMap.containsKey(opName)) { + opGroupedChannelId.put(opName, new ArrayList<>()); + opGroupedActor.put(opName, new ArrayList<>()); } - opPartitionMap.keySet().forEach(opName -> { - collectors.add(new OutputCollector( - writer, opGroupedChannelId.get(opName), - opGroupedActor.get(opName), 
opPartitionMap.get(opName) - )); - }); - } - - // consumer - List inputEdges = executionVertex.getInputEdges(); - List inputChannelIds = new ArrayList<>(); - List inputActors = new ArrayList<>(); - for (ExecutionEdge edge : inputEdges) { - String queueName = ChannelId.genIdStr( - edge.getSourceExecutionVertex().getExecutionVertexId(), - taskId, - executionVertex.getBuildTime()); - inputChannelIds.add(queueName); - inputActors.add(edge.getSourceExecutionVertex().getWorkerActor()); - } - if (!inputActors.isEmpty()) { - LOG.info("Register queue consumer, channels {}.", inputChannelIds); - reader = new DataReader(inputChannelIds, inputActors, jobWorker.getWorkerConfig()); + opGroupedChannelId.get(opName).add(executionVertex.getOutputChannelIdList().get(i)); + opGroupedActor.get(opName).add(executionVertex.getOutputActorList().get(i)); + opPartitionMap.put(opName, edge.getPartition()); } + opPartitionMap.keySet().forEach(opName -> { + collectors.add(new OutputCollector( + writer, opGroupedChannelId.get(opName), + opGroupedActor.get(opName), opPartitionMap.get(opName) + )); + }); RuntimeContext runtimeContext = new StreamingRuntimeContext(executionVertex, jobWorker.getWorkerConfig().configMap, executionVertex.getParallelism()); processor.open(collectors, runtimeContext); - LOG.debug("Finished preparing stream task."); } /** @@ -135,16 +204,6 @@ public abstract class StreamTask implements Runnable { */ protected abstract void init() throws Exception; - /** - * Stop running tasks. - */ - protected abstract void cancelTask() throws Exception; - - public void start() { - LOG.info("Start stream task: {}-{}", this.getClass().getSimpleName(), taskId); - this.thread.start(); - } - /** * Close running tasks. 
*/ @@ -159,4 +218,134 @@ public abstract class StreamTask implements Runnable { LOG.info("Stream task close success."); } + // ---------------------------------------------------------------------- + // Checkpoint + // ---------------------------------------------------------------------- + + public boolean triggerCheckpoint(Long barrierId) { + throw new UnsupportedOperationException("Only source operator supports trigger checkpoints."); + } + + public void doCheckpoint(long checkpointId, Map inputPoints) { + Map outputPoints = null; + if (writer != null) { + outputPoints = writer.getOutputCheckpoints(); + RemoteCall.Barrier barrierPb = + RemoteCall.Barrier.newBuilder().setId(checkpointId).build(); + ByteBuffer byteBuffer = ByteBuffer.wrap(barrierPb.toByteArray()); + byteBuffer.order(ByteOrder.nativeOrder()); + writer.broadcastBarrier(checkpointId, byteBuffer); + } + + LOG.info("Start do checkpoint, cp id={}, inputPoints={}, outputPoints={}.", checkpointId, + inputPoints, outputPoints); + + this.lastCheckpointId = checkpointId; + Serializable processorCheckpoint = processor.saveCheckpoint(); + + try { + OperatorCheckpointInfo opCpInfo = + new OperatorCheckpointInfo(inputPoints, outputPoints, processorCheckpoint, + checkpointId); + saveCpStateAndReport(opCpInfo, checkpointId); + } catch (Exception e) { + // there will be exceptions when flush state to backend. 
+ // we ignore the exception to prevent failover + LOG.error("Processor or op checkpoint exception.", e); + } + + LOG.info("Operator do checkpoint {} finish.", checkpointId); + } + + private void saveCpStateAndReport( + OperatorCheckpointInfo operatorCheckpointInfo, + long checkpointId) { + saveCp(operatorCheckpointInfo, checkpointId); + reportCommit(checkpointId); + + LOG.info("Finish save cp state and report, checkpoint id is {}.", checkpointId); + } + + private void saveCp(OperatorCheckpointInfo operatorCheckpointInfo, long checkpointId) { + byte[] bytes = Serializer.encode(operatorCheckpointInfo); + String cpKey = genOpCheckpointKey(checkpointId); + LOG.info("Saving task checkpoint, cpKey={}, byte len={}, checkpointInfo={}.", cpKey, + bytes.length, operatorCheckpointInfo); + synchronized (checkpointState) { + if (outdatedCheckpoints.contains(checkpointId)) { + LOG.info("Outdated checkpoint, skip save checkpoint."); + outdatedCheckpoints.remove(checkpointId); + } else { + CheckpointStateUtil.put(checkpointState, cpKey, bytes); + } + } + } + + private void reportCommit(long checkpointId) { + final JobWorkerContext context = jobWorker.getWorkerContext(); + LOG.info("Report commit async, checkpoint id {}.", checkpointId); + RemoteCallMaster.reportJobWorkerCommitAsync(context.getMaster(), + new WorkerCommitReport(context.getWorkerActorId(), checkpointId)); + } + + public void notifyCheckpointTimeout(long checkpointId) { + String cpKey = genOpCheckpointKey(checkpointId); + try { + synchronized (checkpointState) { + if (checkpointState.exists(cpKey)) { + checkpointState.remove(cpKey); + } else { + outdatedCheckpoints.add(checkpointId); + } + } + } catch (Exception e) { + LOG.error("Notify checkpoint timeout failed, checkpointId is {}.", checkpointId, e); + } + } + + public void clearExpiredCpState(long checkpointId) { + String cpKey = genOpCheckpointKey(checkpointId); + try { + checkpointState.remove(cpKey); + } catch (Exception e) { + LOG.error("Failed to remove key 
{} from state backend.", cpKey, e); + } + } + + public void clearExpiredQueueMsg(long checkpointId) { + // get operator checkpoint + String cpKey = genOpCheckpointKey(checkpointId); + byte[] bytes; + try { + bytes = checkpointState.get(cpKey); + } catch (Exception e) { + LOG.error("Failed to get key {} from state backend.", cpKey, e); + return; + } + if (bytes != null) { + final OperatorCheckpointInfo operatorCheckpointInfo = Serializer.decode(bytes); + long cpId = operatorCheckpointInfo.checkpointId; + if (writer != null) { + writer.clearCheckpoint(cpId); + } + } + } + + public String genOpCheckpointKey(long checkpointId) { + // TODO: need to support job restart and actorId changed + final JobWorkerContext context = jobWorker.getWorkerContext(); + return jobWorker.getWorkerConfig().checkpointConfig.jobWorkerOpCpPrefixKey() + + context.getJobName() + "_" + context.getWorkerName() + "_" + checkpointId; + } + + // ---------------------------------------------------------------------- + // Failover + // ---------------------------------------------------------------------- + protected void requestRollback(String exceptionMsg) { + jobWorker.requestRollback(exceptionMsg); + } + + public boolean isAlive() { + return this.thread.isAlive(); + } } diff --git a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java index 1bba5b0f5..40870f51a 100644 --- a/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java +++ b/streaming/java/streaming-runtime/src/main/java/io/ray/streaming/runtime/worker/tasks/TwoInputStreamTask.java @@ -10,12 +10,12 @@ import io.ray.streaming.runtime.worker.JobWorker; public class TwoInputStreamTask extends InputStreamTask { public TwoInputStreamTask( - int taskId, Processor processor, JobWorker jobWorker, String leftStream, - String 
rightStream) { - super(taskId, processor, jobWorker); + String rightStream, + long lastCheckpointId) { + super(processor, jobWorker, lastCheckpointId); ((TwoInputProcessor) (super.processor)).setLeftStream(leftStream); ((TwoInputProcessor) (super.processor)).setRightStream(rightStream); } diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java index 05d4f9dc8..9af1899ac 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/core/graph/ExecutionGraphTest.java @@ -15,7 +15,7 @@ import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionJobVertex; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionVertex; import io.ray.streaming.runtime.core.resource.ResourceType; -import io.ray.streaming.runtime.master.JobRuntimeContext; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; import io.ray.streaming.runtime.master.graphmanager.GraphManager; import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; import java.util.HashMap; @@ -34,7 +34,7 @@ public class ExecutionGraphTest extends BaseUnitTest { public void testBuildExecutionGraph() { Map jobConf = new HashMap<>(); StreamingConfig streamingConfig = new StreamingConfig(jobConf); - GraphManager graphManager = new GraphManagerImpl(new JobRuntimeContext(streamingConfig)); + GraphManager graphManager = new GraphManagerImpl(new JobMasterRuntimeContext(streamingConfig)); JobGraph jobGraph = buildJobGraph(); jobGraph.getJobConfig().put("streaming.task.resource.cpu.limitation.enable", "true"); diff --git 
a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java index 8473937a9..b7e2aef61 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/demo/UnionStreamTest.java @@ -36,7 +36,7 @@ public class UnionStreamTest { streamSource1 .union(streamSource2, streamSource3) .sink((SinkFunction) value -> { - LOG.info("UnionStreamTest: {}", value); + LOG.info("UnionStreamTest, sink: {}", value); try { if (!Files.exists(Paths.get(sinkFileName))) { Files.createFile(Paths.get(sinkFileName)); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java index 53f3cc4d1..76658e1ea 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/JobMasterTest.java @@ -14,7 +14,7 @@ public class JobMasterTest { Assert.assertNull(jobMaster.getGraphManager()); Assert.assertNull(jobMaster.getResourceManager()); Assert.assertNull(jobMaster.getJobMasterActor()); - Assert.assertFalse(jobMaster.init()); + Assert.assertFalse(jobMaster.init(false)); } } diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java index 579b1266a..5f3e7db35 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java +++ 
b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/ResourceManagerTest.java @@ -7,7 +7,7 @@ import io.ray.streaming.runtime.BaseUnitTest; import io.ray.streaming.runtime.config.StreamingConfig; import io.ray.streaming.runtime.config.global.CommonConfig; import io.ray.streaming.runtime.core.resource.Container; -import io.ray.streaming.runtime.master.JobRuntimeContext; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; import io.ray.streaming.runtime.util.RayUtils; import java.util.HashMap; import java.util.List; @@ -44,8 +44,8 @@ public class ResourceManagerTest extends BaseUnitTest { Map conf = new HashMap(); conf.put(CommonConfig.JOB_NAME, "testApi"); StreamingConfig config = new StreamingConfig(conf); - JobRuntimeContext jobRuntimeContext = new JobRuntimeContext(config); - ResourceManager resourceManager = new ResourceManagerImpl(jobRuntimeContext); + JobMasterRuntimeContext jobMasterRuntimeContext = new JobMasterRuntimeContext(config); + ResourceManager resourceManager = new ResourceManagerImpl(jobMasterRuntimeContext); // test register container List containers = resourceManager.getRegisteredContainers(); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java index 4a8ea66b1..2e42e606b 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/master/resourcemanager/strategy/PipelineFirstStrategyTest.java @@ -9,7 +9,7 @@ import io.ray.streaming.runtime.core.graph.ExecutionGraphTest; import io.ray.streaming.runtime.core.graph.executiongraph.ExecutionGraph; import 
io.ray.streaming.runtime.core.resource.Container; import io.ray.streaming.runtime.core.resource.ResourceType; -import io.ray.streaming.runtime.master.JobRuntimeContext; +import io.ray.streaming.runtime.master.context.JobMasterRuntimeContext; import io.ray.streaming.runtime.master.graphmanager.GraphManager; import io.ray.streaming.runtime.master.graphmanager.GraphManagerImpl; import io.ray.streaming.runtime.master.resourcemanager.ResourceAssignmentView; @@ -64,7 +64,7 @@ public class PipelineFirstStrategyTest extends BaseUnitTest { Map jobConf = new HashMap<>(); StreamingConfig streamingConfig = new StreamingConfig(jobConf); - GraphManager graphManager = new GraphManagerImpl(new JobRuntimeContext(streamingConfig)); + GraphManager graphManager = new GraphManagerImpl(new JobMasterRuntimeContext(streamingConfig)); JobGraph jobGraph = ExecutionGraphTest.buildJobGraph(); ExecutionGraph executionGraph = ExecutionGraphTest.buildExecutionGraph(graphManager, jobGraph); ResourceAssignmentView assignmentView = strategy.assignResource(containers, executionGraph); diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java index b1760ceb6..879364e04 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java @@ -9,7 +9,7 @@ import io.ray.streaming.api.function.impl.FlatMapFunction; import io.ray.streaming.api.function.impl.ReduceFunction; import io.ray.streaming.api.stream.DataStreamSource; import io.ray.streaming.runtime.BaseUnitTest; -import io.ray.streaming.runtime.transfer.ChannelId; +import io.ray.streaming.runtime.transfer.channel.ChannelId; import io.ray.streaming.runtime.util.EnvUtil; import 
io.ray.streaming.util.Config; import java.io.File; diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java index 4e94d5167..1305267e2 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/streamingqueue/Worker.java @@ -6,11 +6,11 @@ import io.ray.api.Ray; import io.ray.runtime.functionmanager.JavaFunctionDescriptor; import io.ray.streaming.runtime.config.StreamingWorkerConfig; import io.ray.streaming.runtime.transfer.ChannelCreationParametersBuilder; -import io.ray.streaming.runtime.transfer.ChannelId; -import io.ray.streaming.runtime.transfer.DataMessage; import io.ray.streaming.runtime.transfer.DataReader; import io.ray.streaming.runtime.transfer.DataWriter; import io.ray.streaming.runtime.transfer.TransferHandler; +import io.ray.streaming.runtime.transfer.channel.ChannelId; +import io.ray.streaming.runtime.transfer.message.DataMessage; import io.ray.streaming.util.Config; import java.lang.management.ManagementFactory; import java.nio.ByteBuffer; @@ -104,7 +104,7 @@ class ReaderWorker extends Worker { new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessage", "([B)V"), new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessageSync", "([B)[B")); StreamingWorkerConfig workerConfig = new StreamingWorkerConfig(conf); - dataReader = new DataReader(inputQueueList, inputActors, workerConfig); + dataReader = new DataReader(inputQueueList, inputActors, new HashMap<>(), workerConfig); // Should not GetBundle in RayCall thread Thread readThread = new Thread(Ray.wrapRunnable(new Runnable() { @@ -124,7 +124,7 @@ class ReaderWorker extends Worker { int checkPointId = 1; for (int i = 0; i < msgCount * inputQueueList.size(); ++i) { - DataMessage dataMessage = 
dataReader.read(100); + DataMessage dataMessage = (DataMessage) dataReader.read(100); if (dataMessage == null) { LOGGER.error("dataMessage is null"); @@ -232,7 +232,7 @@ class WriterWorker extends Worker { new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessage", "([B)V"), new JavaFunctionDescriptor(Worker.class.getName(), "onReaderMessageSync", "([B)[B")); StreamingWorkerConfig workerConfig = new StreamingWorkerConfig(conf); - dataWriter = new DataWriter(outputQueueList, outputActors, workerConfig); + dataWriter = new DataWriter(outputQueueList, outputActors, new HashMap<>(), workerConfig); Thread writerThread = new Thread(Ray.wrapRunnable(new Runnable() { @Override public void run() { diff --git a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java index 11dcddeda..46270837e 100644 --- a/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java +++ b/streaming/java/streaming-runtime/src/test/java/io/ray/streaming/runtime/transfer/ChannelIdTest.java @@ -4,6 +4,7 @@ import static org.testng.Assert.assertEquals; import io.ray.streaming.runtime.BaseUnitTest; +import io.ray.streaming.runtime.transfer.channel.ChannelId; import io.ray.streaming.runtime.util.EnvUtil; import org.testng.annotations.Test; diff --git a/streaming/java/streaming-runtime/src/test/resources/log4j.properties b/streaming/java/streaming-runtime/src/test/resources/log4j.properties index 30d876aec..8d40bd190 100644 --- a/streaming/java/streaming-runtime/src/test/resources/log4j.properties +++ b/streaming/java/streaming-runtime/src/test/resources/log4j.properties @@ -3,4 +3,4 @@ log4j.rootLogger=INFO, stdout log4j.appender.stdout=org.apache.log4j.ConsoleAppender log4j.appender.stdout.Target=System.out log4j.appender.stdout.layout=org.apache.log4j.PatternLayout 
-log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SS} %-4p %c{1}:%L [%t] - %m%n diff --git a/streaming/python/collector.py b/streaming/python/collector.py index 12b6c096b..1760900fc 100644 --- a/streaming/python/collector.py +++ b/streaming/python/collector.py @@ -68,7 +68,7 @@ class OutputCollector(Collector): python_buffer = self.python_serializer.serialize(record) self._writer.write( self._channel_ids[partition_index], - serialization._PYTHON_TYPE_ID + python_buffer) + bytes([serialization.PYTHON_TYPE_ID]) + python_buffer) else: # avoid repeated serialization if cross_lang_buffer is None: @@ -76,4 +76,5 @@ class OutputCollector(Collector): record) self._writer.write( self._channel_ids[partition_index], - serialization._CROSS_LANG_TYPE_ID + cross_lang_buffer) + bytes([serialization.CROSS_LANG_TYPE_ID]) + + cross_lang_buffer) diff --git a/streaming/python/config.py b/streaming/python/config.py index d7af6230e..b80d49b29 100644 --- a/streaming/python/config.py +++ b/streaming/python/config.py @@ -8,7 +8,6 @@ class Config: NATIVE_CHANNEL = "native_channel" CHANNEL_SIZE = "channel_size" CHANNEL_SIZE_DEFAULT = 10**8 - IS_RECREATE = "streaming.is_recreate" # return from StreamingReader.getBundle if only empty message read in this # interval. 
TIMER_INTERVAL_MS = "timer_interval_ms" @@ -26,3 +25,38 @@ class Config: FLOW_CONTROL_TYPE = "streaming.flow_control_type" WRITER_CONSUMED_STEP = "streaming.writer.consumed_step" READER_CONSUMED_STEP = "streaming.reader.consumed_step" + + # state backend + CP_STATE_BACKEND_TYPE = "streaming.context-backend.type" + CP_STATE_BACKEND_MEMORY = "memory" + CP_STATE_BACKEND_LOCAL_FILE = "local_file" + CP_STATE_BACKEND_DEFAULT = CP_STATE_BACKEND_MEMORY + + # local disk + FILE_STATE_ROOT_PATH = "streaming.context-backend.file-state.root" + FILE_STATE_ROOT_PATH_DEFAULT = "/tmp/ray_streaming_state" + + # checkpoint + JOB_WORKER_CONTEXT_KEY = "jobworker_context_" + + # reliability level + REQUEST_ROLLBACK_RETRY_TIMES = 3 + + # checkpoint prefix key + JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY = "jobwk_op_" + + +class ConfigHelper(object): + @staticmethod + def get_cp_local_file_root_dir(conf): + value = conf.get(Config.FILE_STATE_ROOT_PATH) + if value is not None: + return value + return Config.FILE_STATE_ROOT_PATH_DEFAULT + + @staticmethod + def get_cp_context_backend_type(conf): + value = conf.get(Config.CP_STATE_BACKEND_TYPE) + if value is not None: + return value + return Config.CP_STATE_BACKEND_DEFAULT diff --git a/streaming/python/function.py b/streaming/python/function.py index c9cfedcc6..038132d90 100644 --- a/streaming/python/function.py +++ b/streaming/python/function.py @@ -22,6 +22,12 @@ class Function(ABC): def close(self): pass + def save_checkpoint(self): + pass + + def load_checkpoint(self, checkpoint_obj): + pass + class EmptyFunction(Function): """Default function which does nothing""" @@ -58,12 +64,15 @@ class SourceFunction(Function): pass @abstractmethod - def run(self, ctx: SourceContext): + def fetch(self, ctx: SourceContext): """Starts the source. Implementations can use the :class:`SourceContext` to emit elements. 
""" pass + def close(self): + pass + class MapFunction(Function): """ @@ -176,24 +185,29 @@ class CollectionSourceFunction(SourceFunction): def init(self, parallel, index): pass - def run(self, ctx: SourceContext): + def fetch(self, ctx: SourceContext): for v in self.values: ctx.collect(v) + self.values = [] class LocalFileSourceFunction(SourceFunction): def __init__(self, filename): self.filename = filename + self.done = False def init(self, parallel, index): pass - def run(self, ctx: SourceContext): + def fetch(self, ctx: SourceContext): + if self.done: + return with open(self.filename, "r") as f: line = f.readline() while line != "": ctx.collect(line[:-1]) line = f.readline() + self.done = True class SimpleMapFunction(MapFunction): diff --git a/streaming/python/includes/libstreaming.pxd b/streaming/python/includes/libstreaming.pxd index 08a1ce129..899e51694 100644 --- a/streaming/python/includes/libstreaming.pxd +++ b/streaming/python/includes/libstreaming.pxd @@ -11,6 +11,9 @@ from libcpp.vector cimport vector as c_vector from libcpp.list cimport list as c_list from cpython cimport PyObject cimport cpython +from libcpp.unordered_map cimport unordered_map as c_unordered_map +from cython.operator cimport dereference, postincrement + cdef inline object PyObject_to_object(PyObject* o): # Cast to "object" increments reference count @@ -32,7 +35,7 @@ from ray.includes.unique_ids cimport ( CObjectID, ) -cdef extern from "status.h" namespace "ray::streaming" nogil: +cdef extern from "common/status.h" namespace "ray::streaming" nogil: cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus": pass cdef CStreamingStatus StatusOK "ray::streaming::StreamingStatus::OK" @@ -70,9 +73,21 @@ cdef extern from "message/message.h" namespace "ray::streaming" nogil: cdef CStreamingMessageType MessageTypeMessage "ray::streaming::StreamingMessageType::Message" cdef cppclass CStreamingMessage "ray::streaming::StreamingMessage": inline uint8_t *RawData() const + inline uint8_t 
*Payload() const + inline uint32_t PayloadSize() const inline uint32_t GetDataSize() const inline CStreamingMessageType GetMessageType() const - inline uint64_t GetMessageSeqId() const + inline uint64_t GetMessageId() const + @staticmethod + inline void GetBarrierIdFromRawData(const uint8_t *data, + CStreamingBarrierHeader *barrier_header) + cdef struct CStreamingBarrierHeader "ray::streaming::StreamingBarrierHeader": + CStreamingBarrierType barrier_type; + uint64_t barrier_id; + cdef cppclass CStreamingBarrierType "ray::streaming::StreamingBarrierType": + pass + cdef uint32_t kMessageHeaderSize; + cdef uint32_t kBarrierHeaderSize; cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil: cdef cppclass CStreamingMessageBundleType "ray::streaming::StreamingMessageBundleType": @@ -97,13 +112,40 @@ cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil: void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums, c_list[shared_ptr[CStreamingMessage]] &msg_list); -cdef extern from "channel.h" namespace "ray::streaming" nogil: +cdef extern from "channel/channel.h" namespace "ray::streaming" nogil: cdef struct CChannelCreationParameter "ray::streaming::ChannelCreationParameter": CChannelCreationParameter() CActorID actor_id; shared_ptr[CRayFunction] async_function; shared_ptr[CRayFunction] sync_function; + cdef struct CStreamingQueueInfo "ray::streaming::StreamingQueueInfo": + uint64_t first_seq_id; + uint64_t last_message_id; + uint64_t target_message_id; + uint64_t consumed_message_id; + + cdef struct CConsumerChannelInfo "ray::streaming::ConsumerChannelInfo": + CObjectID channel_id; + uint64_t current_message_id; + uint64_t barrier_id; + uint64_t partial_barrier_id; + CStreamingQueueInfo queue_info; + uint64_t last_queue_item_delay; + uint64_t last_queue_item_latency; + uint64_t last_queue_target_diff; + uint64_t get_queue_item_times; + uint64_t notify_cnt; + CChannelCreationParameter parameter; + + 
cdef enum CTransferCreationStatus "ray::streaming::TransferCreationStatus": + FreshStarted = 0 + PullOk = 1 + Timeout = 2 + DataLost = 3 + Invalid = 999 + + cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil: cdef cppclass CReaderClient "ray::streaming::ReaderClient": CReaderClient() @@ -128,11 +170,12 @@ cdef extern from "data_reader.h" namespace "ray::streaming" nogil: CDataReader(shared_ptr[CRuntimeContext] &runtime_context) void Init(const c_vector[CObjectID] &input_ids, const c_vector[CChannelCreationParameter] ¶ms, - const c_vector[uint64_t] &seq_ids, const c_vector[uint64_t] &msg_ids, + c_vector[CTransferCreationStatus] &creation_status, int64_t timer_interval); CStreamingStatus GetBundle(const uint32_t timeout_ms, shared_ptr[CDataBundle] &message) + void GetOffsetInfo(c_unordered_map[CObjectID, CConsumerChannelInfo] *&offset_map); void Stop() @@ -145,6 +188,9 @@ cdef extern from "data_writer.h" namespace "ray::streaming" nogil: const c_vector[uint64_t] &queue_size_vec); long WriteMessageToBufferRing( const CObjectID &q_id, uint8_t *data, uint32_t data_size) + void BroadcastBarrier(uint64_t checkpoint_id, const uint8_t *data, uint32_t data_size) + void GetChannelOffset(c_vector[uint64_t] &result) + void ClearCheckpoint(uint64_t checkpoint_id) void Run() void Stop() diff --git a/streaming/python/includes/transfer.pxi b/streaming/python/includes/transfer.pxi index 8061ad547..4952fd8b5 100644 --- a/streaming/python/includes/transfer.pxi +++ b/streaming/python/includes/transfer.pxi @@ -6,6 +6,8 @@ from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast from libcpp.string cimport string as c_string from libcpp.vector cimport vector as c_vector from libcpp.list cimport list as c_list +from libcpp.unordered_map cimport unordered_map as c_unordered_map +from cython.operator cimport dereference, postincrement from ray.includes.common cimport ( CRayFunction, @@ -38,6 +40,10 @@ from ray.streaming.includes.libstreaming cimport ( 
CWriterClient, CLocalMemoryBuffer, CChannelCreationParameter, + CTransferCreationStatus, + CConsumerChannelInfo, + CStreamingBarrierHeader, + kBarrierHeaderSize, ) from ray._raylet import JavaFunctionDescriptor @@ -191,7 +197,7 @@ cdef class DataWriter: self.writer = NULL def write(self, ObjectRef qid, const unsigned char[:] value): - """support zero-copy bytes, bytearray, array of unsigned char""" + """support zero-copy bytes, byte array, array of unsigned char""" cdef: CObjectID native_id = qid.data uint64_t msg_id @@ -201,6 +207,25 @@ cdef class DataWriter: msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size) return msg_id + def broadcast_barrier(self, uint64_t checkpoint_id, const unsigned char[:] value): + cdef: + uint8_t *data = (&value[0]) + uint32_t size = value.nbytes + with nogil: + self.writer.BroadcastBarrier(checkpoint_id, data, size) + + def get_output_checkpoints(self): + cdef: + c_vector[uint64_t] results + self.writer.GetChannelOffset(results) + return results + + def clear_checkpoint(self, checkpoint_id): + cdef: + uint64_t c_checkpoint_id = checkpoint_id + with nogil: + self.writer.ClearCheckpoint(c_checkpoint_id) + def stop(self): self.writer.Stop() channel_logger.info("stopped DataWriter") @@ -218,25 +243,22 @@ cdef class DataReader: @staticmethod def create(list py_input_queues, list input_creation_parameters: list[ChannelCreationParameter], - list py_seq_ids, list py_msg_ids, int64_t timer_interval, - c_bool is_recreate, bytes config_bytes, c_bool is_mock): cdef: c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues) c_vector[CChannelCreationParameter] initial_parameters - c_vector[uint64_t] seq_ids c_vector[uint64_t] msg_ids + c_vector[CTransferCreationStatus] c_creation_status CDataReader *c_reader ChannelCreationParameter parameter cdef const unsigned char[:] config_data for param in input_creation_parameters: parameter = param initial_parameters.push_back(parameter.get_parameter()) - for py_seq_id in 
py_seq_ids: - seq_ids.push_back(py_seq_id) + for py_msg_id in py_msg_ids: msg_ids.push_back(py_msg_id) cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]() @@ -247,11 +269,19 @@ cdef class DataReader: if is_mock: ctx.get().MarkMockTest() c_reader = new CDataReader(ctx) - c_reader.Init(queue_id_vec, initial_parameters, seq_ids, msg_ids, timer_interval) + c_reader.Init(queue_id_vec, initial_parameters, msg_ids, c_creation_status, timer_interval) + + creation_status_map = {} + if not c_creation_status.empty(): + for i in range(queue_id_vec.size()): + k = queue_id_vec[i].Binary() + v = c_creation_status[i] + creation_status_map[k] = v + channel_logger.info("create native reader succeed") cdef DataReader reader = DataReader.__new__(DataReader) reader.reader = c_reader - return reader + return reader, creation_status_map def __dealloc__(self): if self.reader != NULL: @@ -265,23 +295,33 @@ cdef class DataReader: CStreamingStatus status with nogil: status = self.reader.GetBundle(timeout_millis, bundle) - cdef uint32_t bundle_type = (bundle.get().meta.get().GetBundleType()) if status != libstreaming.StatusOK: if status == libstreaming.StatusInterrupted: # avoid cyclic import import ray.streaming.runtime.transfer as transfer raise transfer.ChannelInterruptException("reader interrupted") elif status == libstreaming.StatusInitQueueFailed: - raise Exception("init channel failed") - elif status == libstreaming.StatusWaitQueueTimeOut: - raise Exception("wait channel object timeout") + import ray.streaming.runtime.transfer as transfer + raise transfer.ChannelInitException("init channel failed") + elif status == libstreaming.StatusGetBundleTimeOut: + return [] + else: + raise Exception("no such status " + str(status)) cdef: uint32_t msg_nums - CObjectID queue_id + CObjectID queue_id = bundle.get().c_from c_list[shared_ptr[CStreamingMessage]] msg_list list msgs = [] uint64_t timestamp uint64_t msg_id + c_unordered_map[CObjectID, CConsumerChannelInfo] *offset_map = 
NULL + shared_ptr[CStreamingMessage] barrier + CStreamingBarrierHeader barrier_header + c_unordered_map[CObjectID, CConsumerChannelInfo].iterator it + + cdef uint32_t bundle_type = (bundle.get().meta.get().GetBundleType()) + # avoid cyclic import + from ray.streaming.runtime.transfer import DataMessage if bundle_type == libstreaming.BundleTypeBundle: msg_nums = bundle.get().meta.get().GetMessageListSize() CStreamingMessageBundle.GetMessageListFromRawData( @@ -291,16 +331,48 @@ cdef class DataReader: msg_list) timestamp = bundle.get().meta.get().GetMessageBundleTs() for msg in msg_list: - msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()] + msg_bytes = msg.get().Payload()[:msg.get().PayloadSize()] qid_bytes = queue_id.Binary() - msg_id = msg.get().GetMessageSeqId() - msgs.append((msg_bytes, msg_id, timestamp, qid_bytes)) + msg_id = msg.get().GetMessageId() + msgs.append( + DataMessage(msg_bytes, timestamp, msg_id, qid_bytes)) return msgs elif bundle_type == libstreaming.BundleTypeEmpty: - return [] + timestamp = bundle.get().meta.get().GetMessageBundleTs() + msg_id = bundle.get().meta.get().GetLastMessageId() + return [DataMessage(None, timestamp, msg_id, queue_id.Binary(), True)] + elif bundle.get().meta.get().IsBarrier(): + py_offset_map = {} + self.reader.GetOffsetInfo(offset_map) + it = offset_map.begin() + while it != offset_map.end(): + queue_id_bytes = dereference(it).first.Binary() + current_message_id = dereference(it).second.current_message_id + py_offset_map[queue_id_bytes] = current_message_id + postincrement(it) + msg_nums = bundle.get().meta.get().GetMessageListSize() + CStreamingMessageBundle.GetMessageListFromRawData( + bundle.get().data + libstreaming.kMessageBundleHeaderSize, + bundle.get().data_size - libstreaming.kMessageBundleHeaderSize, + msg_nums, + msg_list) + timestamp = bundle.get().meta.get().GetMessageBundleTs() + barrier = msg_list.front() + msg_id = barrier.get().GetMessageId() + 
CStreamingMessage.GetBarrierIdFromRawData(barrier.get().Payload(), &barrier_header) + barrier_id = barrier_header.barrier_id + barrier_data = (barrier.get().Payload() + kBarrierHeaderSize)[ + :barrier.get().PayloadSize() - kBarrierHeaderSize] + barrier_type = barrier_header.barrier_type + py_queue_id = queue_id.Binary() + from ray.streaming.runtime.transfer import CheckpointBarrier + return [CheckpointBarrier( + barrier_data, timestamp, msg_id, py_queue_id, py_offset_map, + barrier_id, barrier_type)] else: raise Exception("Unsupported bundle type {}".format(bundle_type)) + def stop(self): self.reader.Stop() channel_logger.info("stopped DataReader") diff --git a/streaming/python/operator.py b/streaming/python/operator.py index 4952b2d00..9163519d6 100644 --- a/streaming/python/operator.py +++ b/streaming/python/operator.py @@ -3,10 +3,11 @@ import importlib import logging from abc import ABC, abstractmethod -from ray import streaming from ray.streaming import function from ray.streaming import message from ray.streaming.collector import Collector +from ray.streaming.collector import CollectionCollector +from ray.streaming.function import SourceFunction from ray.streaming.runtime import gateway_client logger = logging.getLogger(__name__) @@ -40,6 +41,14 @@ class Operator(ABC): def operator_type(self) -> OperatorType: pass + @abstractmethod + def save_checkpoint(self): + pass + + @abstractmethod + def load_checkpoint(self, checkpoint_obj): + pass + class OneInputOperator(Operator, ABC): """Interface for stream operators with one input.""" @@ -90,8 +99,20 @@ class StreamOperator(Operator, ABC): for collector in self.collectors: collector.collect(record) + def save_checkpoint(self): + self.func.save_checkpoint() -class SourceOperator(StreamOperator): + def load_checkpoint(self, checkpoint_obj): + self.func.load_checkpoint(checkpoint_obj) + + +class SourceOperator(Operator, ABC): + @abstractmethod + def fetch(self): + pass + + +class SourceOperatorImpl(SourceOperator, 
StreamOperator): """ Operator to run a :class:`function.SourceFunction` """ @@ -104,19 +125,19 @@ class SourceOperator(StreamOperator): for collector in self.collectors: collector.collect(message.Record(value)) - def __init__(self, func): + def __init__(self, func: SourceFunction): assert isinstance(func, function.SourceFunction) super().__init__(func) self.source_context = None def open(self, collectors, runtime_context): super().open(collectors, runtime_context) - self.source_context = SourceOperator.SourceContextImpl(collectors) + self.source_context = SourceOperatorImpl.SourceContextImpl(collectors) self.func.init(runtime_context.get_parallelism(), runtime_context.get_task_index()) - def run(self): - self.func.run(self.source_context) + def fetch(self): + self.func.fetch(self.source_context) def operator_type(self): return OperatorType.SOURCE @@ -147,8 +168,7 @@ class FlatMapOperator(StreamOperator, OneInputOperator): def open(self, collectors, runtime_context): super().open(collectors, runtime_context) - self.collection_collector = streaming.collector.CollectionCollector( - collectors) + self.collection_collector = CollectionCollector(collectors) def process_element(self, record): self.func.flat_map(record.value, self.collection_collector) @@ -286,12 +306,12 @@ class ChainedOperator(StreamOperator, ABC): raise Exception("Current operator type is not supported") -class ChainedSourceOperator(ChainedOperator): +class ChainedSourceOperator(SourceOperator, ChainedOperator): def __init__(self, operators, configs): super().__init__(operators, configs) - def run(self): - self.operators[0].run() + def fetch(self): + self.operators[0].fetch() class ChainedOneInputOperator(ChainedOperator): @@ -350,7 +370,7 @@ def load_operator(descriptor_operator_bytes: bytes): _function_to_operator = { - function.SourceFunction: SourceOperator, + function.SourceFunction: SourceOperatorImpl, function.MapFunction: MapOperator, function.FlatMapFunction: FlatMapOperator, 
function.FilterFunction: FilterOperator, diff --git a/streaming/python/runtime/command.py b/streaming/python/runtime/command.py new file mode 100644 index 000000000..cc5f02e1f --- /dev/null +++ b/streaming/python/runtime/command.py @@ -0,0 +1,30 @@ +class BaseWorkerCmd: + """ + base worker cmd + """ + + def __init__(self, actor_id): + self.from_actor_id = actor_id + + +class WorkerCommitReport(BaseWorkerCmd): + """ + worker commit report + """ + + def __init__(self, actor_id, commit_checkpoint_id): + super().__init__(actor_id) + self.commit_checkpoint_id = commit_checkpoint_id + + +class WorkerRollbackRequest(BaseWorkerCmd): + """ + worker rollback request + """ + + def __init__(self, actor_id, exception_msg): + super().__init__(actor_id) + self.__exception_msg = exception_msg + + def exception_msg(self): + return self.__exception_msg diff --git a/streaming/python/runtime/context_backend.py b/streaming/python/runtime/context_backend.py new file mode 100644 index 000000000..65e811cfe --- /dev/null +++ b/streaming/python/runtime/context_backend.py @@ -0,0 +1,117 @@ +import logging +import os +from abc import ABC, abstractmethod +from os import path + +from ray.streaming.config import ConfigHelper, Config + +logger = logging.getLogger(__name__) + + +class ContextBackend(ABC): + @abstractmethod + def get(self, key): + pass + + @abstractmethod + def put(self, key, value): + pass + + @abstractmethod + def remove(self, key): + pass + + +class MemoryContextBackend(ContextBackend): + def __init__(self, conf): + self.__dic = dict() + + def get(self, key): + return self.__dic.get(key) + + def put(self, key, value): + self.__dic[key] = value + + def remove(self, key): + if key in self.__dic: + del self.__dic[key] + + +class LocalFileContextBackend(ContextBackend): + def __init__(self, conf): + self.__dir = ConfigHelper.get_cp_local_file_root_dir(conf) + logger.info("Start init local file state backend, root_dir={}.".format( + self.__dir)) + try: + os.mkdir(self.__dir) + except 
FileExistsError: + logger.info("dir already exists, skipped.") + + def put(self, key, value): + logger.info("Put value of key {} start.".format(key)) + with open(self.__gen_file_path(key), "wb") as f: + f.write(value) + + def get(self, key): + logger.info("Get value of key {} start.".format(key)) + full_path = self.__gen_file_path(key) + if not os.path.isfile(full_path): + return None + with open(full_path, "rb") as f: + return f.read() + + def remove(self, key): + logger.info("Remove value of key {} start.".format(key)) + try: + os.remove(self.__gen_file_path(key)) + except Exception: + # ignore exception + pass + + def rename(self, src, dst): + logger.info("rename {} to {}".format(src, dst)) + os.rename(self.__gen_file_path(src), self.__gen_file_path(dst)) + + def exists(self, key) -> bool: + return os.path.exists(key) + + def __gen_file_path(self, key): + return path.join(self.__dir, key) + + +class AtomicFsContextBackend(LocalFileContextBackend): + def __init__(self, conf): + super().__init__(conf) + self.__tmp_flag = "_tmp" + + def put(self, key, value): + tmp_key = key + self.__tmp_flag + if super().exists(tmp_key) and not super().exists(key): + super().rename(tmp_key, key) + super().put(tmp_key, value) + super().remove(key) + super().rename(tmp_key, key) + + def get(self, key): + tmp_key = key + self.__tmp_flag + if super().exists(tmp_key) and not super().exists(key): + return super().get(tmp_key) + return super().get(key) + + def remove(self, key): + tmp_key = key + self.__tmp_flag + if super().exists(tmp_key): + super().remove(tmp_key) + super().remove(key) + + +class ContextBackendFactory: + @staticmethod + def get_context_backend(worker_config) -> ContextBackend: + backend_type = ConfigHelper.get_cp_context_backend_type(worker_config) + context_backend = None + if backend_type == Config.CP_STATE_BACKEND_LOCAL_FILE: + context_backend = AtomicFsContextBackend(worker_config) + elif backend_type == Config.CP_STATE_BACKEND_MEMORY: + context_backend = 
MemoryContextBackend(worker_config) + return context_backend diff --git a/streaming/python/runtime/failover.py b/streaming/python/runtime/failover.py new file mode 100644 index 000000000..702cdbab3 --- /dev/null +++ b/streaming/python/runtime/failover.py @@ -0,0 +1,30 @@ +class Barrier: + """ + barrier + """ + + def __init__(self, id): + self.id = id + + def __str__(self): + return "Barrier [id:%s]" % self.id + + +class OpCheckpointInfo: + """ + operator checkpoint info + """ + + def __init__(self, + operator_point=None, + input_points=None, + output_points=None, + checkpoint_id=None): + if input_points is None: + input_points = {} + if output_points is None: + output_points = {} + self.operator_point = operator_point + self.input_points = input_points + self.output_points = output_points + self.checkpoint_id = checkpoint_id diff --git a/streaming/python/runtime/graph.py b/streaming/python/runtime/graph.py index 6db9cf39e..d2719ee12 100644 --- a/streaming/python/runtime/graph.py +++ b/streaming/python/runtime/graph.py @@ -5,6 +5,9 @@ import ray import ray.streaming.generated.remote_call_pb2 as remote_call_pb import ray.streaming.operator as operator import ray.streaming.partition as partition +from ray._raylet import ActorID +from ray.actor import ActorHandle +from ray.streaming.config import Config from ray.streaming.generated.streaming_pb2 import Language logger = logging.getLogger(__name__) @@ -27,10 +30,12 @@ class NodeType(enum.Enum): class ExecutionEdge: - def __init__(self, edge_pb, language): - self.source_execution_vertex_id = edge_pb.source_execution_vertex_id - self.target_execution_vertex_id = edge_pb.target_execution_vertex_id - partition_bytes = edge_pb.partition + def __init__(self, execution_edge_pb, language): + self.source_execution_vertex_id = execution_edge_pb \ + .source_execution_vertex_id + self.target_execution_vertex_id = execution_edge_pb \ + .target_execution_vertex_id + partition_bytes = execution_edge_pb.partition # Sink node doesn't 
have partition function, # so we only deserialize partition_bytes when it's not None or empty if language == Language.PYTHON and partition_bytes: @@ -38,50 +43,73 @@ class ExecutionEdge: class ExecutionVertex: - def __init__(self, vertex_pb): - self.execution_vertex_id = vertex_pb.execution_vertex_id - self.execution_job_vertex_Id = vertex_pb.execution_job_vertex_Id - self.execution_job_vertex_name = vertex_pb.execution_job_vertex_name - self.execution_vertex_index = vertex_pb.execution_vertex_index - self.parallelism = vertex_pb.parallelism - if vertex_pb.language == Language.PYTHON: - operator_bytes = vertex_pb.operator # python operator descriptor - if vertex_pb.chained: + worker_actor: ActorHandle + + def __init__(self, execution_vertex_pb): + self.execution_vertex_id = execution_vertex_pb.execution_vertex_id + self.execution_job_vertex_id = execution_vertex_pb \ + .execution_job_vertex_id + self.execution_job_vertex_name = execution_vertex_pb \ + .execution_job_vertex_name + self.execution_vertex_index = execution_vertex_pb\ + .execution_vertex_index + self.parallelism = execution_vertex_pb.parallelism + if execution_vertex_pb\ + .language == Language.PYTHON: + # python operator descriptor + operator_bytes = execution_vertex_pb.operator + if execution_vertex_pb.chained: logger.info("Load chained operator") self.stream_operator = operator.load_chained_operator( operator_bytes) else: logger.info("Load operator") self.stream_operator = operator.load_operator(operator_bytes) - self.worker_actor = ray.actor.ActorHandle. \ - _deserialization_helper(vertex_pb.worker_actor) - self.container_id = vertex_pb.container_id - self.build_time = vertex_pb.build_time - self.language = vertex_pb.language - self.config = vertex_pb.config - self.resource = vertex_pb.resource + self.worker_actor = None + if execution_vertex_pb.worker_actor: + self.worker_actor = ray.actor.ActorHandle. 
\ + _deserialization_helper(execution_vertex_pb.worker_actor) + self.container_id = execution_vertex_pb.container_id + self.build_time = execution_vertex_pb.build_time + self.language = execution_vertex_pb.language + self.config = execution_vertex_pb.config + self.resource = execution_vertex_pb.resource + + @property + def execution_vertex_name(self): + return "{}_{}_{}".format(self.execution_job_vertex_id, + self.execution_job_vertex_name, + self.execution_vertex_id) class ExecutionVertexContext: - def __init__(self, - vertex_context_pb: remote_call_pb.ExecutionVertexContext): - self.execution_vertex = \ - ExecutionVertex(vertex_context_pb.current_execution_vertex) + actor_id: ActorID + execution_vertex: ExecutionVertex + + def __init__( + self, + execution_vertex_context_pb: remote_call_pb.ExecutionVertexContext + ): + self.execution_vertex = ExecutionVertex( + execution_vertex_context_pb.current_execution_vertex) + self.job_name = self.execution_vertex.config[Config.STREAMING_JOB_NAME] + self.exe_vertex_name = self.execution_vertex.execution_vertex_name + self.actor_id = self.execution_vertex.worker_actor._ray_actor_id self.upstream_execution_vertices = [ - ExecutionVertex(vertex) - for vertex in vertex_context_pb.upstream_execution_vertices + ExecutionVertex(vertex) for vertex in + execution_vertex_context_pb.upstream_execution_vertices ] self.downstream_execution_vertices = [ - ExecutionVertex(vertex) - for vertex in vertex_context_pb.downstream_execution_vertices + ExecutionVertex(vertex) for vertex in + execution_vertex_context_pb.downstream_execution_vertices ] self.input_execution_edges = [ ExecutionEdge(edge, self.execution_vertex.language) - for edge in vertex_context_pb.input_execution_edges + for edge in execution_vertex_context_pb.input_execution_edges ] self.output_execution_edges = [ ExecutionEdge(edge, self.execution_vertex.language) - for edge in vertex_context_pb.output_execution_edges + for edge in 
execution_vertex_context_pb.output_execution_edges ] def get_parallelism(self): @@ -112,16 +140,16 @@ class ExecutionVertexContext: def get_task_id(self): return self.execution_vertex.execution_vertex_id - def get_source_actor_by_vertex_id(self, execution_vertex_id): - for vertex in self.upstream_execution_vertices: - if vertex.execution_vertex_id == execution_vertex_id: - return vertex.worker_actor - raise Exception("ExecutionVertex %s does not exist!" - .format(execution_vertex_id)) + def get_source_actor_by_execution_vertex_id(self, execution_vertex_id): + for execution_vertex in self.upstream_execution_vertices: + if execution_vertex.execution_vertex_id == execution_vertex_id: + return execution_vertex.worker_actor + raise Exception( + "Vertex %s does not exist!".format(execution_vertex_id)) - def get_target_actor_by_vertex_id(self, execution_vertex_id): - for vertex in self.downstream_execution_vertices: - if vertex.execution_vertex_id == execution_vertex_id: - return vertex.worker_actor - raise Exception("ExecutionVertex %s does not exist!" 
- .format(execution_vertex_id)) + def get_target_actor_by_execution_vertex_id(self, execution_vertex_id): + for execution_vertex in self.downstream_execution_vertices: + if execution_vertex.execution_vertex_id == execution_vertex_id: + return execution_vertex.worker_actor + raise Exception( + "Vertex %s does not exist!".format(execution_vertex_id)) diff --git a/streaming/python/runtime/processor.py b/streaming/python/runtime/processor.py index ccfa55921..1083713ee 100644 --- a/streaming/python/runtime/processor.py +++ b/streaming/python/runtime/processor.py @@ -23,6 +23,14 @@ class Processor(ABC): def close(self): pass + @abstractmethod + def save_checkpoint(self): + pass + + @abstractmethod + def load_checkpoint(self, checkpoint_obj): + pass + class StreamingProcessor(Processor, ABC): """StreamingProcessor is a process unit for a operator.""" @@ -40,7 +48,13 @@ class StreamingProcessor(Processor, ABC): logger.info("Opened Processor {}".format(self)) def close(self): - pass + self.operator.close() + + def save_checkpoint(self): + self.operator.save_checkpoint() + + def load_checkpoint(self, checkpoint_obj): + self.operator.load_checkpoint(checkpoint_obj) class SourceProcessor(StreamingProcessor): @@ -52,8 +66,8 @@ class SourceProcessor(StreamingProcessor): def process(self, record): raise Exception("SourceProcessor should not process record") - def run(self): - self.operator.run() + def fetch(self): + self.operator.fetch() class OneInputProcessor(StreamingProcessor): diff --git a/streaming/python/runtime/remote_call.py b/streaming/python/runtime/remote_call.py new file mode 100644 index 000000000..4f5f082ee --- /dev/null +++ b/streaming/python/runtime/remote_call.py @@ -0,0 +1,95 @@ +import logging +import os +import ray +import time +from enum import Enum + +from ray.actor import ActorHandle +from ray.streaming.generated import remote_call_pb2 +from ray.streaming.runtime.command\ + import WorkerCommitReport, WorkerRollbackRequest + +logger = 
logging.getLogger(__name__) + + +class CallResult: + """ + Call Result + """ + + def __init__(self, success, result_code, result_msg, result_obj): + self.success = success + self.result_code = result_code + self.result_msg = result_msg + self.result_obj = result_obj + + @staticmethod + def success(payload=None): + return CallResult(True, CallResultEnum.SUCCESS, None, payload) + + @staticmethod + def fail(payload=None): + return CallResult(False, CallResultEnum.FAILED, None, payload) + + @staticmethod + def skipped(msg=None): + return CallResult(True, CallResultEnum.SKIPPED, msg, None) + + def is_success(self): + if self.result_code is CallResultEnum.SUCCESS: + return True + + return False + + +class CallResultEnum(Enum): + """ + call result enum + """ + + SUCCESS = 0 + FAILED = 1 + SKIPPED = 2 + + +class RemoteCallMst: + """ + remote call job master + """ + + @staticmethod + def request_job_worker_rollback(master: ActorHandle, + request: WorkerRollbackRequest): + logger.info("Remote call mst: request job worker rollback start.") + request_pb = remote_call_pb2.BaseWorkerCmd() + request_pb.actor_id = request.from_actor_id + request_pb.timestamp = int(time.time() * 1000.0) + rollback_request_pb = remote_call_pb2.WorkerRollbackRequest() + rollback_request_pb.exception_msg = request.exception_msg() + rollback_request_pb.worker_hostname = os.uname()[1] + rollback_request_pb.worker_pid = str(os.getpid()) + request_pb.detail.Pack(rollback_request_pb) + return_ids = master.requestJobWorkerRollback\ + .remote(request_pb.SerializeToString()) + result = remote_call_pb2.BoolResult() + result.ParseFromString(ray.get(return_ids)) + logger.info("Remote call mst: request job worker rollback finish.") + return result.boolRes + + @staticmethod + def report_job_worker_commit(master: ActorHandle, + report: WorkerCommitReport): + logger.info("Remote call mst: report job worker commit start.") + report_pb = remote_call_pb2.BaseWorkerCmd() + + report_pb.actor_id = report.from_actor_id + 
report_pb.timestamp = int(time.time() * 1000.0) + wk_commit = remote_call_pb2.WorkerCommitReport() + wk_commit.commit_checkpoint_id = report.commit_checkpoint_id + report_pb.detail.Pack(wk_commit) + return_id = master.reportJobWorkerCommit\ + .remote(report_pb.SerializeToString()) + result = remote_call_pb2.BoolResult() + result.ParseFromString(ray.get(return_id)) + logger.info("Remote call mst: report job worker commit finish.") + return result.boolRes diff --git a/streaming/python/runtime/serialization.py b/streaming/python/runtime/serialization.py index 2d038e482..600e1084c 100644 --- a/streaming/python/runtime/serialization.py +++ b/streaming/python/runtime/serialization.py @@ -3,11 +3,11 @@ import pickle import msgpack from ray.streaming import message -_RECORD_TYPE_ID = 0 -_KEY_RECORD_TYPE_ID = 1 -_CROSS_LANG_TYPE_ID = b"0" -_JAVA_TYPE_ID = b"1" -_PYTHON_TYPE_ID = b"2" +RECORD_TYPE_ID = 0 +KEY_RECORD_TYPE_ID = 1 +CROSS_LANG_TYPE_ID = 0 +JAVA_TYPE_ID = 1 +PYTHON_TYPE_ID = 2 class Serializer(ABC): @@ -33,21 +33,21 @@ class CrossLangSerializer(Serializer): def serialize(self, obj): if type(obj) is message.Record: - fields = [_RECORD_TYPE_ID, obj.stream, obj.value] + fields = [RECORD_TYPE_ID, obj.stream, obj.value] elif type(obj) is message.KeyRecord: - fields = [_KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value] + fields = [KEY_RECORD_TYPE_ID, obj.stream, obj.key, obj.value] else: raise Exception("Unsupported value {}".format(obj)) return msgpack.packb(fields, use_bin_type=True) def deserialize(self, data): - fields = msgpack.unpackb(data, raw=False, strict_map_key=False) - if fields[0] == _RECORD_TYPE_ID: + fields = msgpack.unpackb(data, raw=False) + if fields[0] == RECORD_TYPE_ID: stream, value = fields[1:] record = message.Record(value) record.stream = stream return record - elif fields[0] == _KEY_RECORD_TYPE_ID: + elif fields[0] == KEY_RECORD_TYPE_ID: stream, key, value = fields[1:] key_record = message.KeyRecord(key, value) key_record.stream = stream diff 
--git a/streaming/python/runtime/task.py b/streaming/python/runtime/task.py index 713a29703..54ec3cf38 100644 --- a/streaming/python/runtime/task.py +++ b/streaming/python/runtime/task.py @@ -1,14 +1,30 @@ import logging +import pickle import threading +import time +import typing from abc import ABC, abstractmethod +from typing import Optional from ray.streaming.collector import OutputCollector from ray.streaming.config import Config from ray.streaming.context import RuntimeContextImpl +from ray.streaming.generated import remote_call_pb2 from ray.streaming.runtime import serialization +from ray.streaming.runtime.command import WorkerCommitReport +from ray.streaming.runtime.failover import Barrier, OpCheckpointInfo +from ray.streaming.runtime.remote_call import RemoteCallMst from ray.streaming.runtime.serialization import \ PythonSerializer, CrossLangSerializer +from ray.streaming.runtime.transfer import CheckpointBarrier +from ray.streaming.runtime.transfer import DataMessage from ray.streaming.runtime.transfer import ChannelID, DataWriter, DataReader +from ray.streaming.runtime.transfer import ChannelRecoverInfo +from ray.streaming.runtime.transfer import ChannelInterruptException + +if typing.TYPE_CHECKING: + from ray.streaming.runtime.worker import JobWorker + from ray.streaming.runtime.processor import Processor, SourceProcessor logger = logging.getLogger(__name__) @@ -16,18 +32,85 @@ logger = logging.getLogger(__name__) class StreamTask(ABC): """Base class for all streaming tasks. 
Each task runs a processor.""" - def __init__(self, task_id, processor, worker): + def __init__(self, task_id: int, processor: "Processor", + worker: "JobWorker", last_checkpoint_id: int): + self.worker_context = worker.worker_context + self.vertex_context = worker.execution_vertex_context self.task_id = task_id self.processor = processor self.worker = worker - self.config = worker.config - self.reader = None # DataReader - self.writers = {} # ExecutionEdge -> DataWriter - self.thread = None - self.prepare_task() + self.config: dict = worker.config + self.reader: Optional[DataReader] = None + self.writer: Optional[DataWriter] = None + self.is_initial_state = True + self.last_checkpoint_id: int = last_checkpoint_id self.thread = threading.Thread(target=self.run, daemon=True) - def prepare_task(self): + def do_checkpoint(self, checkpoint_id: int, input_points): + logger.info("Start do checkpoint, cp id {}, inputPoints {}.".format( + checkpoint_id, input_points)) + + output_points = None + if self.writer is not None: + output_points = self.writer.get_output_checkpoints() + + operator_checkpoint = self.processor.save_checkpoint() + op_checkpoint_info = OpCheckpointInfo( + operator_checkpoint, input_points, output_points, checkpoint_id) + self.__save_cp_state_and_report(op_checkpoint_info, checkpoint_id) + + barrier_pb = remote_call_pb2.Barrier() + barrier_pb.id = checkpoint_id + byte_buffer = barrier_pb.SerializeToString() + if self.writer is not None: + self.writer.broadcast_barrier(checkpoint_id, byte_buffer) + logger.info("Operator checkpoint {} finish.".format(checkpoint_id)) + + def __save_cp_state_and_report(self, op_checkpoint_info, checkpoint_id): + logger.info( + "Start to save cp state and report, checkpoint id is {}.".format( + checkpoint_id)) + self.__save_cp(op_checkpoint_info, checkpoint_id) + self.__report_commit(checkpoint_id) + self.last_checkpoint_id = checkpoint_id + + def __save_cp(self, op_checkpoint_info, checkpoint_id): + logger.info("save 
operator cp, op_checkpoint_info={}".format( + op_checkpoint_info)) + cp_bytes = pickle.dumps(op_checkpoint_info) + self.worker.context_backend.put( + self.__gen_op_checkpoint_key(checkpoint_id), cp_bytes) + + def __report_commit(self, checkpoint_id: int): + logger.info("Report commit, checkpoint id {}.".format(checkpoint_id)) + report = WorkerCommitReport(self.vertex_context.actor_id.binary(), + checkpoint_id) + RemoteCallMst.report_job_worker_commit(self.worker.master_actor, + report) + + def clear_expired_cp_state(self, checkpoint_id): + cp_key = self.__gen_op_checkpoint_key(checkpoint_id) + self.worker.context_backend.remove(cp_key) + + def clear_expired_queue_msg(self, checkpoint_id): + # clear operator checkpoint + if self.writer is not None: + self.writer.clear_checkpoint(checkpoint_id) + + def request_rollback(self, exception_msg: str): + self.worker.request_rollback(exception_msg) + + def __gen_op_checkpoint_key(self, checkpoint_id): + op_checkpoint_key = Config.JOB_WORKER_OP_CHECKPOINT_PREFIX_KEY + str( + self.vertex_context.job_name) + "_" + str( + self.vertex_context.exe_vertex_name) + "_" + str(checkpoint_id) + logger.info( + "Generate op checkpoint key {}. 
".format(op_checkpoint_key)) + return op_checkpoint_key + + def prepare_task(self, is_recreate: bool): + logger.info( + "Preparing stream task, is_recreate={}.".format(is_recreate)) channel_conf = dict(self.worker.config) channel_size = int( self.worker.config.get(Config.CHANNEL_SIZE, @@ -39,45 +122,76 @@ class StreamTask(ABC): execution_vertex_context = self.worker.execution_vertex_context build_time = execution_vertex_context.build_time + # when use memory state, if actor throw exception, will miss state + op_checkpoint_info = OpCheckpointInfo() + + cp_bytes = None + # get operator checkpoint + if is_recreate: + cp_key = self.__gen_op_checkpoint_key(self.last_checkpoint_id) + logger.info("Getting task checkpoints from state, " + "cpKey={}, checkpointId={}.".format( + cp_key, self.last_checkpoint_id)) + cp_bytes = self.worker.context_backend.get(cp_key) + if cp_bytes is None: + msg = "Task recover failed, checkpoint is null!"\ + "cpKey={}".format(cp_key) + raise RuntimeError(msg) + + if cp_bytes is not None: + op_checkpoint_info = pickle.loads(cp_bytes) + self.processor.load_checkpoint(op_checkpoint_info.operator_point) + logger.info("Stream task recover from checkpoint state," + "checkpoint bytes len={}, checkpointInfo={}.".format( + cp_bytes.__len__(), op_checkpoint_info)) + # writers collectors = [] output_actors_map = {} for edge in execution_vertex_context.output_execution_edges: target_task_id = edge.target_execution_vertex_id - target_actor = execution_vertex_context\ - .get_target_actor_by_vertex_id(target_task_id) + target_actor = execution_vertex_context \ + .get_target_actor_by_execution_vertex_id(target_task_id) channel_name = ChannelID.gen_id(self.task_id, target_task_id, build_time) output_actors_map[channel_name] = target_actor - if len(output_actors_map) > 0: - channel_ids = list(output_actors_map.keys()) - target_actors = list(output_actors_map.values()) - logger.info( - "Create DataWriter channel_ids {}, target_actors {}." 
- .format(channel_ids, target_actors)) - writer = DataWriter(channel_ids, target_actors, channel_conf) - self.writers[edge] = writer + if len(output_actors_map) > 0: + channel_str_ids = list(output_actors_map.keys()) + target_actors = list(output_actors_map.values()) + logger.info("Create DataWriter channel_ids {}," + "target_actors {}, output_points={}.".format( + channel_str_ids, target_actors, + op_checkpoint_info.output_points)) + self.writer = DataWriter(channel_str_ids, target_actors, + channel_conf) + logger.info("Create DataWriter succeed channel_ids {}, " + "target_actors {}.".format(channel_str_ids, + target_actors)) + for edge in execution_vertex_context.output_execution_edges: collectors.append( - OutputCollector(writer, channel_ids, target_actors, - edge.partition)) + OutputCollector(self.writer, channel_str_ids, + target_actors, edge.partition)) # readers input_actor_map = {} for edge in execution_vertex_context.input_execution_edges: source_task_id = edge.source_execution_vertex_id - source_actor = execution_vertex_context\ - .get_source_actor_by_vertex_id(source_task_id) + source_actor = execution_vertex_context \ + .get_source_actor_by_execution_vertex_id(source_task_id) channel_name = ChannelID.gen_id(source_task_id, self.task_id, build_time) input_actor_map[channel_name] = source_actor if len(input_actor_map) > 0: - channel_ids = list(input_actor_map.keys()) + channel_str_ids = list(input_actor_map.keys()) from_actors = list(input_actor_map.values()) - logger.info("Create DataReader, channels {}, input_actors {}." 
- .format(channel_ids, from_actors)) - self.reader = DataReader(channel_ids, from_actors, channel_conf) + logger.info("Create DataReader, channels {}," + "input_actors {}, input_points={}.".format( + channel_str_ids, from_actors, + op_checkpoint_info.input_points)) + self.reader = DataReader(channel_str_ids, from_actors, + channel_conf) def exit_handler(): # Make DataReader stop read data when MockQueue destructor @@ -87,21 +201,31 @@ class StreamTask(ABC): import atexit atexit.register(exit_handler) - # TODO(chaokunyang) add task/job config runtime_context = RuntimeContextImpl( self.worker.task_id, execution_vertex_context.execution_vertex.execution_vertex_index, - execution_vertex_context.get_parallelism()) + execution_vertex_context.get_parallelism(), + config=channel_conf, + job_config=channel_conf) logger.info("open Processor {}".format(self.processor)) self.processor.open(collectors, runtime_context) - @abstractmethod - def init(self): - pass + # immediately save cp. In case of FO in cp 0 + # or use old cp in multi node FO. 
+ self.__save_cp(op_checkpoint_info, self.last_checkpoint_id) + + def recover(self, is_recreate: bool): + self.prepare_task(is_recreate) + + recover_info = ChannelRecoverInfo() + if self.reader is not None: + recover_info = self.reader.get_channel_recover_info() - def start(self): self.thread.start() + logger.info("Start operator success.") + return recover_info + @abstractmethod def run(self): pass @@ -110,14 +234,24 @@ class StreamTask(ABC): def cancel_task(self): pass + @abstractmethod + def commit_trigger(self, barrier: Barrier) -> bool: + pass + class InputStreamTask(StreamTask): """Base class for stream tasks that execute a :class:`runtime.processor.OneInputProcessor` or :class:`runtime.processor.TwoInputProcessor` """ - def __init__(self, task_id, processor_instance, worker): - super().__init__(task_id, processor_instance, worker) + def commit_trigger(self, barrier): + raise RuntimeError( + "commit_trigger is only supported in SourceStreamTask.") + + def __init__(self, task_id, processor_instance, worker, + last_checkpoint_id): + super().__init__(task_id, processor_instance, worker, + last_checkpoint_id) self.running = True self.stopped = False self.read_timeout_millis = \ @@ -126,25 +260,58 @@ class InputStreamTask(StreamTask): self.python_serializer = PythonSerializer() self.cross_lang_serializer = CrossLangSerializer() - def init(self): - pass - def run(self): - while self.running: - item = self.reader.read(self.read_timeout_millis) - if item is not None: - msg_data = item.body() - type_id = msg_data[:1] - if (type_id == serialization._PYTHON_TYPE_ID): - msg = self.python_serializer.deserialize(msg_data[1:]) + logger.info("Input task thread start.") + try: + while self.running: + self.worker.initial_state_lock.acquire() + try: + item = self.reader.read(self.read_timeout_millis) + self.is_initial_state = False + finally: + self.worker.initial_state_lock.release() + + if item is None: + continue + + if isinstance(item, DataMessage): + msg_data = item.body + 
type_id = msg_data[0] + if type_id == serialization.PYTHON_TYPE_ID: + msg = self.python_serializer.deserialize(msg_data[1:]) + else: + msg = self.cross_lang_serializer.deserialize( + msg_data[1:]) + self.processor.process(msg) + elif isinstance(item, CheckpointBarrier): + logger.info("Got barrier:{}".format(item)) + logger.info("Start to do checkpoint {}.".format( + item.checkpoint_id)) + + input_points = item.get_input_checkpoints() + + self.do_checkpoint(item.checkpoint_id, input_points) + logger.info("Do checkpoint {} success.".format( + item.checkpoint_id)) else: - msg = self.cross_lang_serializer.deserialize(msg_data[1:]) - self.processor.process(msg) + raise RuntimeError( + "Unknown item type! item={}".format(item)) + + except ChannelInterruptException: + logger.info("queue has stopped.") + except BaseException as e: + logger.exception( + "Last success checkpointId={}, now occur error.".format( + self.last_checkpoint_id)) + self.request_rollback(str(e)) + + logger.info("Source fetcher thread exit.") self.stopped = True def cancel_task(self): self.running = False while not self.stopped: + time.sleep(0.5) pass @@ -152,22 +319,64 @@ class OneInputStreamTask(InputStreamTask): """A stream task for executing :class:`runtime.processor.OneInputProcessor` """ - def __init__(self, task_id, processor_instance, worker): - super().__init__(task_id, processor_instance, worker) + def __init__(self, task_id, processor_instance, worker, + last_checkpoint_id): + super().__init__(task_id, processor_instance, worker, + last_checkpoint_id) class SourceStreamTask(StreamTask): """A stream task for executing :class:`runtime.processor.SourceProcessor` """ + processor: "SourceProcessor" - def __init__(self, task_id, processor_instance, worker): - super().__init__(task_id, processor_instance, worker) - - def init(self): - pass + def __init__(self, task_id: int, processor_instance: "SourceProcessor", + worker: "JobWorker", last_checkpoint_id): + super().__init__(task_id, 
processor_instance, worker, + last_checkpoint_id) + self.running = True + self.stopped = False + self.__pending_barrier: Optional[Barrier] = None def run(self): - self.processor.run() + logger.info("Source task thread start.") + try: + while self.running: + self.processor.fetch() + # check checkpoint + if self.__pending_barrier is not None: + # source fetcher only have outputPoints + barrier = self.__pending_barrier + logger.info("Start to do checkpoint {}.".format( + barrier.id)) + self.do_checkpoint(barrier.id, barrier) + logger.info("Finish to do checkpoint {}.".format( + barrier.id)) + self.__pending_barrier = None + + except ChannelInterruptException: + logger.info("queue has stopped.") + except Exception as e: + logger.exception( + "Last success checkpointId={}, now occur error.".format( + self.last_checkpoint_id)) + self.request_rollback(str(e)) + + logger.info("Source fetcher thread exit.") + self.stopped = True + + def commit_trigger(self, barrier): + if self.__pending_barrier is not None: + logger.warning( + "Last barrier is not broadcast now, skip this barrier trigger." 
+ ) + return False + + self.__pending_barrier = barrier + return True def cancel_task(self): - pass + self.running = False + while not self.stopped: + time.sleep(0.5) + pass diff --git a/streaming/python/runtime/transfer.py b/streaming/python/runtime/transfer.py index 5a19bec9a..8091c1d21 100644 --- a/streaming/python/runtime/transfer.py +++ b/streaming/python/runtime/transfer.py @@ -2,6 +2,8 @@ import logging import random from queue import Queue from typing import List +from enum import Enum +from abc import ABC, abstractmethod import ray import ray.streaming._streaming as _streaming @@ -13,6 +15,7 @@ from ray._raylet import PythonFunctionDescriptor from ray._raylet import Language CHANNEL_ID_LEN = 20 +logger = logging.getLogger(__name__) class ChannelID: @@ -97,40 +100,70 @@ def channel_bytes_to_str(id_bytes): return bytes.hex(id_bytes) -class DataMessage: +class Message(ABC): + @property + @abstractmethod + def body(self): + """Message data""" + pass + + @property + @abstractmethod + def timestamp(self): + """Get timestamp when item is written by upstream DataWriter + """ + pass + + @property + @abstractmethod + def channel_id(self): + """Get string id of channel where data is coming from""" + pass + + @property + @abstractmethod + def message_id(self): + """Get message id of the message""" + pass + + +class DataMessage(Message): """ - DataMessage represents data between upstream and downstream operator + DataMessage represents data between upstream and downstream operator. 
""" def __init__(self, body, timestamp, + message_id, channel_id, - message_id_, is_empty_message=False): self.__body = body self.__timestamp = timestamp self.__channel_id = channel_id - self.__message_id = message_id_ + self.__message_id = message_id self.__is_empty_message = is_empty_message def __len__(self): return len(self.__body) + @property def body(self): - """Message data""" return self.__body + @property def timestamp(self): - """Get timestamp when item is written by upstream DataWriter - """ return self.__timestamp + @property def channel_id(self): - """Get string id of channel where data is coming from - """ return self.__channel_id + @property + def message_id(self): + return self.__message_id + + @property def is_empty_message(self): """Whether this message is an empty message. Upstream DataWriter will send an empty message when this is no data @@ -138,10 +171,47 @@ class DataMessage: """ return self.__is_empty_message + +class CheckpointBarrier(Message): + """ + CheckpointBarrier separates the records in the data stream into the set of + records that goes into the current snapshot, and the records that go into + the next snapshot. Each barrier carries the ID of the snapshot whose + records it pushed in front of it. 
+ """ + + def __init__(self, barrier_data, timestamp, message_id, channel_id, + offsets, barrier_id, barrier_type): + self.__barrier_data = barrier_data + self.__timestamp = timestamp + self.__message_id = message_id + self.__channel_id = channel_id + self.checkpoint_id = barrier_id + self.offsets = offsets + self.barrier_type = barrier_type + + @property + def body(self): + return self.__barrier_data + + @property + def timestamp(self): + return self.__timestamp + + @property + def channel_id(self): + return self.__channel_id + @property def message_id(self): return self.__message_id + def get_input_checkpoints(self): + return self.offsets + + def __str__(self): + return "Barrier(Checkpoint id : {})".format(self.checkpoint_id) + class ChannelCreationParametersBuilder: """ @@ -218,9 +288,6 @@ class ChannelCreationParametersBuilder: _python_reader_sync_function_descriptor = sync_function -logger = logging.getLogger(__name__) - - class DataWriter: """Data Writer is a wrapper of streaming c++ DataWriter, which sends data to downstream workers @@ -264,6 +331,26 @@ class DataWriter: msg_id = self.writer.write(channel_id.object_qid, item) return msg_id + def broadcast_barrier(self, checkpoint_id: int, body: bytes): + """Broadcast barriers to all downstream channels + Args: + checkpoint_id: the checkpoint_id + body: barrier payload + """ + self.writer.broadcast_barrier(checkpoint_id, body) + + def get_output_checkpoints(self) -> List[int]: + """Get output offsets of all downstream channels + Returns: + a list contains current msg_id of each downstream channel + """ + return self.writer.get_output_checkpoints() + + def clear_checkpoint(self, checkpoint_id): + logger.info("producer start to clear checkpoint, checkpoint_id={}" + .format(checkpoint_id)) + self.writer.clear_checkpoint(checkpoint_id) + def stop(self): logger.info("stopping channel writer.") self.writer.stop() @@ -294,18 +381,20 @@ class DataReader: ] creation_parameters = ChannelCreationParametersBuilder() 
creation_parameters.build_input_queue_parameters(from_actors) - py_seq_ids = [0 for _ in range(len(input_channels))] py_msg_ids = [0 for _ in range(len(input_channels))] timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1)) - is_recreate = bool(conf.get(Config.IS_RECREATE, False)) config_bytes = _to_native_conf(conf) self.__queue = Queue(10000) is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL - self.reader = _streaming.DataReader.create( + self.reader, queues_creation_status = _streaming.DataReader.create( py_input_channels, creation_parameters.get_parameters(), - py_seq_ids, py_msg_ids, timer_interval, is_recreate, config_bytes, - is_mock) - logger.info("create DataReader succeed") + py_msg_ids, timer_interval, config_bytes, is_mock) + + self.__creation_status = {} + for q, status in queues_creation_status.items(): + self.__creation_status[q] = ChannelCreationStatus(status) + logger.info("create DataReader succeed, creation_status={}".format( + self.__creation_status)) def read(self, timeout_millis): """Read data from channel @@ -316,16 +405,17 @@ class DataReader: channel item """ if self.__queue.empty(): - msgs = self.reader.read(timeout_millis) - for msg in msgs: - msg_bytes, msg_id, timestamp, qid_bytes = msg - data_msg = DataMessage(msg_bytes, timestamp, - channel_bytes_to_str(qid_bytes), msg_id) - self.__queue.put(data_msg) + messages = self.reader.read(timeout_millis) + for message in messages: + self.__queue.put(message) + if self.__queue.empty(): return None return self.__queue.get() + def get_channel_recover_info(self): + return ChannelRecoverInfo(self.__creation_status) + def stop(self): logger.info("stopping Data Reader.") self.reader.stop() @@ -372,3 +462,45 @@ class ChannelInitException(Exception): class ChannelInterruptException(Exception): def __init__(self, msg=None): self.msg = msg + + +class ChannelRecoverInfo: + def __init__(self, queue_creation_status_map=None): + if queue_creation_status_map is None: + 
queue_creation_status_map = {} + self.__queue_creation_status_map = queue_creation_status_map + + def get_creation_status(self): + return self.__queue_creation_status_map + + def get_data_lost_queues(self): + data_lost_queues = set() + for (q, status) in self.__queue_creation_status_map.items(): + if status == ChannelCreationStatus.DataLost: + data_lost_queues.add(q) + return data_lost_queues + + def __str__(self): + return "QueueRecoverInfo [dataLostQueues=%s]" \ + % (self.get_data_lost_queues()) + + +class ChannelCreationStatus(Enum): + FreshStarted = 0 + PullOk = 1 + Timeout = 2 + DataLost = 3 + + +def channel_id_bytes_to_str(id_bytes): + """ + Args: + id_bytes: bytes representation of channel id + + Returns: + string representation of channel id + """ + assert type(id_bytes) in [str, bytes] + if isinstance(id_bytes, str): + return id_bytes + return bytes.hex(id_bytes) diff --git a/streaming/python/runtime/worker.py b/streaming/python/runtime/worker.py index 0fdb56096..d6d8eb02b 100644 --- a/streaming/python/runtime/worker.py +++ b/streaming/python/runtime/worker.py @@ -1,12 +1,23 @@ -import logging +import enum +import logging.config +import os +import threading +import time +from typing import Optional import ray -import ray.streaming._streaming as _streaming -import ray.streaming.generated.remote_call_pb2 as remote_call_pb import ray.streaming.runtime.processor as processor -from ray.streaming.config import Config -from ray.streaming.runtime.graph import ExecutionVertexContext +from ray.actor import ActorHandle +from ray.streaming.generated import remote_call_pb2 +from ray.streaming.runtime.command import WorkerRollbackRequest +from ray.streaming.runtime.failover import Barrier +from ray.streaming.runtime.graph import ExecutionVertexContext, ExecutionVertex +from ray.streaming.runtime.remote_call import CallResult, RemoteCallMst +from ray.streaming.runtime.context_backend import ContextBackendFactory from ray.streaming.runtime.task import SourceStreamTask, 
OneInputStreamTask +from ray.streaming.runtime.transfer import channel_bytes_to_str +from ray.streaming.config import Config +import ray.streaming._streaming as _streaming logger = logging.getLogger(__name__) @@ -18,74 +29,179 @@ _NOT_READY_FLAG_ = b" " * 4 class JobWorker(object): """A streaming job worker is used to execute user-defined function and interact with `JobMaster`""" + master_actor: Optional[ActorHandle] + worker_context: Optional[remote_call_pb2.PythonJobWorkerContext] + execution_vertex_context: Optional[ExecutionVertexContext] + __need_rollback: bool - def __init__(self): + def __init__(self, execution_vertex_pb_bytes): + logger.info("Creating job worker, pid={}".format(os.getpid())) + execution_vertex_pb = remote_call_pb2\ + .ExecutionVertexContext.ExecutionVertex() + execution_vertex_pb.ParseFromString(execution_vertex_pb_bytes) + self.execution_vertex = ExecutionVertex(execution_vertex_pb) + self.config = self.execution_vertex.config self.worker_context = None self.execution_vertex_context = None - self.config = None self.task_id = None self.task = None self.stream_processor = None + self.master_actor = None + self.context_backend = ContextBackendFactory.get_context_backend( + self.config) + self.initial_state_lock = threading.Lock() + self.__rollback_cnt: int = 0 + self.__is_recreate: bool = False + self.__state = WorkerState() + self.__need_rollback = True self.reader_client = None self.writer_client = None - logger.info("Creating job worker succeeded.") + try: + # load checkpoint + was_reconstructed = ray.get_runtime_context( + ).was_current_actor_reconstructed + + logger.info( + "Worker was reconstructed: {}".format(was_reconstructed)) + if was_reconstructed: + job_worker_context_key = self.__get_job_worker_context_key() + logger.info("Worker get checkpoint state by key: {}.".format( + job_worker_context_key)) + context_bytes = self.context_backend.get( + job_worker_context_key) + if context_bytes is not None and context_bytes.__len__() > 0: 
+ self.init(context_bytes) + self.request_rollback( + "Python worker recover from checkpoint.") + else: + logger.error( + "Error! Worker get checkpoint state by key {}" + " returns None, please check your state backend" + ", only reliable state backend supports fail-over." + .format(job_worker_context_key)) + except Exception: + logger.exception("Error in __init__ of JobWorker") + logger.info("Creating job worker succeeded. worker config {}".format( + self.config)) def init(self, worker_context_bytes): - worker_context = remote_call_pb.PythonJobWorkerContext() - worker_context.ParseFromString(worker_context_bytes) - self.worker_context = worker_context + logger.info("Start to init job worker") + try: + # deserialize context + worker_context = remote_call_pb2.PythonJobWorkerContext() + worker_context.ParseFromString(worker_context_bytes) + self.worker_context = worker_context + self.master_actor = ActorHandle._deserialization_helper( + worker_context.master_actor) - # build vertex context from pb - self.execution_vertex_context = ExecutionVertexContext( - worker_context.execution_vertex_context) + # build vertex context from pb + self.execution_vertex_context = ExecutionVertexContext( + worker_context.execution_vertex_context) + self.execution_vertex = self\ + .execution_vertex_context.execution_vertex - # use vertex id as task id - self.task_id = self.execution_vertex_context.get_task_id() + # save context + job_worker_context_key = self.__get_job_worker_context_key() + self.context_backend.put(job_worker_context_key, + worker_context_bytes) - # build and get processor from operator - operator = self.execution_vertex_context.stream_operator - self.stream_processor = processor.build_processor(operator) - logger.info( - "Initializing job worker, task_id: {}, operator: {}.".format( - self.task_id, self.stream_processor)) + # use vertex id as task id + self.task_id = self.execution_vertex_context.get_task_id() + # build and get processor from operator + operator = 
self.execution_vertex_context.stream_operator + self.stream_processor = processor.build_processor(operator) + logger.info("Initializing job worker, exe_vertex_name={}," + "task_id: {}, operator: {}, pid={}".format( + self.execution_vertex_context.exe_vertex_name, + self.task_id, self.stream_processor, os.getpid())) - # get config from vertex - self.config = self.execution_vertex_context.config + # get config from vertex + self.config = self.execution_vertex_context.config - if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL): - self.reader_client = _streaming.ReaderClient() - self.writer_client = _streaming.WriterClient() + if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL): + self.reader_client = _streaming.ReaderClient() + self.writer_client = _streaming.WriterClient() - self.task = self.create_stream_task() - - logger.info("Job worker init succeeded.") + logger.info("Job worker init succeeded.") + except Exception: + logger.exception("Error when init job worker.") + return False return True - def start(self): - self.task.start() - logger.info("Job worker start succeeded.") - - def create_stream_task(self): + def create_stream_task(self, checkpoint_id): if isinstance(self.stream_processor, processor.SourceProcessor): - return SourceStreamTask(self.task_id, self.stream_processor, self) + return SourceStreamTask(self.task_id, self.stream_processor, self, + checkpoint_id) elif isinstance(self.stream_processor, processor.OneInputProcessor): return OneInputStreamTask(self.task_id, self.stream_processor, - self) + self, checkpoint_id) else: raise Exception("Unsupported processor type: " + - type(self.stream_processor)) + str(type(self.stream_processor))) - def on_reader_message(self, buffer: bytes): + def rollback(self, checkpoint_id_bytes): + checkpoint_id_pb = remote_call_pb2.CheckpointId() + checkpoint_id_pb.ParseFromString(checkpoint_id_bytes) + checkpoint_id = checkpoint_id_pb.checkpoint_id + + logger.info("Start rollback, 
checkpoint_id={}".format(checkpoint_id)) + + self.__rollback_cnt += 1 + if self.__rollback_cnt > 1: + self.__is_recreate = True + # skip useless rollback + self.initial_state_lock.acquire() + try: + if self.task is not None and self.task.thread.is_alive()\ + and checkpoint_id == self.task.last_checkpoint_id\ + and self.task.is_initial_state: + logger.info( + "Task is already in initial state, skip this rollback.") + return self.__gen_call_result( + CallResult.skipped( + "Task is already in initial state, skip this rollback." + )) + finally: + self.initial_state_lock.release() + + # restart task + try: + if self.task is not None: + # make sure the runner is closed + self.task.cancel_task() + del self.task + + self.task = self.create_stream_task(checkpoint_id) + + q_recover_info = self.task.recover(self.__is_recreate) + + self.__state.set_type(StateType.RUNNING) + self.__need_rollback = False + + logger.info( + "Rollback success, checkpoint is {}, qRecoverInfo is {}.". + format(checkpoint_id, q_recover_info)) + + return self.__gen_call_result(CallResult.success(q_recover_info)) + except Exception: + logger.exception("Rollback has exception.") + return self.__gen_call_result(CallResult.fail()) + + def on_reader_message(self, *buffers): """Called by upstream queue writer to send data message to downstream queue reader. """ - self.reader_client.on_reader_message(buffer) + if self.reader_client is None: + logger.info("reader_client is None, skip writer transfer") + return + self.reader_client.on_reader_message(*buffers) def on_reader_message_sync(self, buffer: bytes): - """Called by upstream queue writer to send control message to downstream - downstream queue reader. + """Called by upstream queue writer to send + control message to downstream downstream queue reader. 
""" if self.reader_client is None: + logger.info("task is None, skip reader transfer") return _NOT_READY_FLAG_ result = self.reader_client.on_reader_message_sync(buffer) return result.to_pybytes() @@ -94,6 +210,9 @@ class JobWorker(object): """Called by downstream queue reader to send notify message to upstream queue writer. """ + if self.writer_client is None: + logger.info("writer_client is None, skip writer transfer") + return self.writer_client.on_writer_message(buffer) def on_writer_message_sync(self, buffer: bytes): @@ -104,3 +223,164 @@ class JobWorker(object): return _NOT_READY_FLAG_ result = self.writer_client.on_writer_message_sync(buffer) return result.to_pybytes() + + def shutdown_without_reconstruction(self): + logger.info("Python worker shutdown without reconstruction.") + ray.actor.exit_actor() + + def notify_checkpoint_timeout(self, checkpoint_id_bytes): + pass + + def commit(self, barrier_bytes): + barrier_pb = remote_call_pb2.Barrier() + barrier_pb.ParseFromString(barrier_bytes) + barrier = Barrier(barrier_pb.id) + logger.info("Receive trigger, barrier is {}.".format(barrier)) + + if self.task is not None: + self.task.commit_trigger(barrier) + ret = remote_call_pb2.BoolResult() + ret.boolRes = True + return ret.SerializeToString() + + def clear_expired_cp(self, state_checkpoint_id_bytes, + queue_checkpoint_id_bytes): + state_checkpoint_id = self.__parse_to_checkpoint_id( + state_checkpoint_id_bytes) + queue_checkpoint_id = self.__parse_to_checkpoint_id( + queue_checkpoint_id_bytes) + logger.info("Start to clear expired checkpoint, checkpoint_id={}," + "queue_checkpoint_id={}, exe_vertex_name={}.".format( + state_checkpoint_id, queue_checkpoint_id, + self.execution_vertex_context.exe_vertex_name)) + + ret = remote_call_pb2.BoolResult() + ret.boolRes = self.__clear_expired_cp_state(state_checkpoint_id) \ + if state_checkpoint_id > 0 else True + ret.boolRes &= self.__clear_expired_queue_msg(queue_checkpoint_id) + logger.info( + "Clear expired 
checkpoint done, result={}, checkpoint_id={}," + "queue_checkpoint_id={}, exe_vertex_name={}.".format( + ret.boolRes, state_checkpoint_id, queue_checkpoint_id, + self.execution_vertex_context.exe_vertex_name)) + return ret.SerializeToString() + + def __clear_expired_cp_state(self, checkpoint_id): + if self.__need_rollback: + logger.warning("Need rollback, skip clear_expired_cp_state" + ", checkpoint id: {}".format(checkpoint_id)) + return False + + logger.info("Clear expired checkpoint state, cp id is {}.".format( + checkpoint_id)) + + if self.task is not None: + self.task.clear_expired_cp_state(checkpoint_id) + return True + + def __clear_expired_queue_msg(self, checkpoint_id): + if self.__need_rollback: + logger.warning("Need rollback, skip clear_expired_queue_msg" + ", checkpoint id: {}".format(checkpoint_id)) + return False + + logger.info("Clear expired queue msg, checkpoint_id is {}.".format( + checkpoint_id)) + + if self.task is not None: + self.task.clear_expired_queue_msg(checkpoint_id) + return True + + def __parse_to_checkpoint_id(self, checkpoint_id_bytes): + checkpoint_id_pb = remote_call_pb2.CheckpointId() + checkpoint_id_pb.ParseFromString(checkpoint_id_bytes) + return checkpoint_id_pb.checkpoint_id + + def check_if_need_rollback(self): + ret = remote_call_pb2.BoolResult() + ret.boolRes = self.__need_rollback + return ret.SerializeToString() + + def request_rollback(self, exception_msg="Python exception."): + logger.info("Request rollback.") + + self.__need_rollback = True + self.__is_recreate = True + + request_ret = False + for i in range(Config.REQUEST_ROLLBACK_RETRY_TIMES): + logger.info("request rollback {} time".format(i)) + try: + request_ret = RemoteCallMst.request_job_worker_rollback( + self.master_actor, + WorkerRollbackRequest( + self.execution_vertex_context.actor_id.binary(), + "Exception msg=%s, retry time=%d." 
% (exception_msg, + i))) + except Exception: + logger.exception("Unexpected error when rollback") + logger.info("request rollback {} time, ret={}".format( + i, request_ret)) + if not request_ret: + logger.warning( + "Request rollback return false" + ", maybe it's invalid request, try to sleep 1s.") + time.sleep(1) + else: + break + if not request_ret: + logger.warning("Request failed after retry {} times," + "now worker shutdown without reconstruction." + .format(Config.REQUEST_ROLLBACK_RETRY_TIMES)) + self.shutdown_without_reconstruction() + + self.__state.set_type(StateType.WAIT_ROLLBACK) + + def __gen_call_result(self, call_result): + call_result_pb = remote_call_pb2.CallResult() + + call_result_pb.success = call_result.success + call_result_pb.result_code = call_result.result_code.value + if call_result.result_msg is not None: + call_result_pb.result_msg = call_result.result_msg + + if call_result.result_obj is not None: + q_recover_info = call_result.result_obj + for q, status in q_recover_info.get_creation_status().items(): + call_result_pb.result_obj.creation_status[channel_bytes_to_str( + q)] = status.value + + return call_result_pb.SerializeToString() + + def _gen_unique_key(self, key_prefix): + return key_prefix \ + + str(self.config.get(Config.STREAMING_JOB_NAME)) \ + + "_" + str(self.execution_vertex.execution_vertex_id) + + def __get_job_worker_context_key(self) -> str: + return self._gen_unique_key(Config.JOB_WORKER_CONTEXT_KEY) + + +class WorkerState: + """ + worker state + """ + + def __init__(self): + self.__type = StateType.INIT + + def set_type(self, type): + self.__type = type + + def get_type(self): + return self.__type + + +class StateType(enum.Enum): + """ + state type + """ + + INIT = 1 + RUNNING = 2 + WAIT_ROLLBACK = 3 diff --git a/streaming/python/tests/test_direct_transfer.py b/streaming/python/tests/test_direct_transfer.py index 5a8866d22..9a9f2892c 100644 --- a/streaming/python/tests/test_direct_transfer.py +++ 
b/streaming/python/tests/test_direct_transfer.py @@ -68,7 +68,7 @@ class Worker: if item is None: time.sleep(0.01) else: - msg = pickle.loads(item.body()) + msg = pickle.loads(item.body) count += 1 assert msg == msg_nums - 1 print("ReaderWorker done.") diff --git a/streaming/python/tests/test_failover.py b/streaming/python/tests/test_failover.py new file mode 100644 index 000000000..def93f43e --- /dev/null +++ b/streaming/python/tests/test_failover.py @@ -0,0 +1,107 @@ +import subprocess +import time +from typing import List + +import ray +from ray.streaming import StreamingContext + + +def test_word_count(): + try: + ray.init(_load_code_from_local=True, _include_java=True) + # time.sleep(10) # for gdb to attach + ctx = StreamingContext.Builder() \ + .option("streaming.context-backend.type", "local_file") \ + .option( + "streaming.context-backend.file-state.root", + "/tmp/ray/cp_files/" + ) \ + .option("streaming.checkpoint.timeout.secs", "3") \ + .build() + + print("-----------submit job-------------") + + ctx.read_text_file(__file__) \ + .set_parallelism(1) \ + .flat_map(lambda x: x.split()) \ + .map(lambda x: (x, 1)) \ + .key_by(lambda x: x[0]) \ + .reduce(lambda old_value, new_value: + (old_value[0], old_value[1] + new_value[1])) \ + .filter(lambda x: "ray" not in x) \ + .sink(lambda x: print("####result", x)) + ctx.submit("word_count") + + print("-----------checking output-------------") + retry_count = 180 / 5 # wait for 3min + while not has_sink_output(): + time.sleep(5) + retry_count -= 1 + if retry_count <= 0: + raise RuntimeError("Can not find output") + + print("-----------killing worker-------------") + time.sleep(5) + kill_all_worker() + + print("-----------checking checkpoint-------------") + cp_ok_num = checkpoint_success_num() + retry_count = 300000 / 5 # wait for 5min + while True: + cur_cp_num = checkpoint_success_num() + print("-----------checking checkpoint" + ", cur_cp_num={}, old_cp_num={}-------------".format( + cur_cp_num, cp_ok_num)) + if 
cur_cp_num > cp_ok_num: + print("--------------TEST OK!------------------") + break + time.sleep(5) + retry_count -= 1 + if retry_count <= 0: + raise RuntimeError( + "Checkpoint keeps failing after fail-over, test failed!") + finally: + ray.shutdown() + + +def run_cmd(cmd: List): + try: + out = subprocess.check_output(cmd).decode() + except subprocess.CalledProcessError as e: + out = str(e) + return out + + +def grep_log(keyword: str) -> str: + out = subprocess.check_output( + ["grep", "-r", keyword, "/tmp/ray/session_latest/logs"]) + return out.decode() + + +def has_sink_output() -> bool: + try: + grep_log("####result") + return True + except Exception: + return False + + +def checkpoint_success_num() -> int: + try: + return grep_log("Finish checkpoint").count("\n") + except Exception: + return 0 + + +def kill_all_worker(): + cmd = [ + "bash", "-c", "grep -r \'Initializing job worker, exe_vert\' " + " /tmp/ray/session_latest/logs | awk -F\'pid\' \'{print $2}\'" + "| awk -F\'=\' \'{print $2}\'" + "| xargs kill -9" + ] + print(cmd) + return subprocess.run(cmd) + + +if __name__ == "__main__": + test_word_count() diff --git a/streaming/src/channel.cc b/streaming/src/channel/channel.cc similarity index 73% rename from streaming/src/channel.cc rename to streaming/src/channel/channel.cc index 6816bf972..896ea2169 100644 --- a/streaming/src/channel.cc +++ b/streaming/src/channel/channel.cc @@ -25,30 +25,10 @@ StreamingQueueProducer::~StreamingQueueProducer() { StreamingStatus StreamingQueueProducer::CreateTransferChannel() { CreateQueue(); - uint64_t queue_last_seq_id = 0; - uint64_t last_message_id_in_queue = 0; + STREAMING_LOG(WARNING) << "Message id in channel => " + << channel_info_.current_message_id; - if (!last_message_id_in_queue) { - if (last_message_id_in_queue < channel_info_.current_message_id) { - STREAMING_LOG(WARNING) << "last message id in queue : " << last_message_id_in_queue - << " is less than message checkpoint loaded id : " - << 
channel_info_.current_message_id - << ", an old queue object " << channel_info_.channel_id - << " was fond in store"; - } - last_message_id_in_queue = channel_info_.current_message_id; - } - if (queue_last_seq_id == static_cast(-1)) { - queue_last_seq_id = 0; - } - channel_info_.current_seq_id = queue_last_seq_id; - - STREAMING_LOG(WARNING) << "existing last message id => " << last_message_id_in_queue - << ", message id in channel => " - << channel_info_.current_message_id << ", queue last seq id => " - << queue_last_seq_id; - - channel_info_.message_last_commit_id = last_message_id_in_queue; + channel_info_.message_last_commit_id = 0; return StreamingStatus::OK; } @@ -69,11 +49,8 @@ StreamingStatus StreamingQueueProducer::CreateQueue() { channel_info_.queue_size); STREAMING_CHECK(queue_ != nullptr); - std::vector queue_ids, failed_queues; - queue_ids.push_back(channel_info_.channel_id); - upstream_handler->WaitQueues(queue_ids, 10 * 1000, failed_queues); - - STREAMING_LOG(INFO) << "q id => " << channel_info_.channel_id << ", queue size => " + STREAMING_LOG(INFO) << "StreamingQueueProducer CreateQueue queue id => " + << channel_info_.channel_id << ", queue size => " << channel_info_.queue_size; return StreamingStatus::OK; @@ -89,21 +66,29 @@ StreamingStatus StreamingQueueProducer::ClearTransferCheckpoint( } StreamingStatus StreamingQueueProducer::RefreshChannelInfo() { - channel_info_.queue_info.consumed_seq_id = queue_->GetMinConsumedSeqID(); + channel_info_.queue_info.consumed_message_id = queue_->GetMinConsumedMsgID(); return StreamingStatus::OK; } -StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t channel_offset) { - queue_->SetQueueEvictionLimit(channel_offset); +StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t msg_id) { + queue_->SetQueueEvictionLimit(msg_id); return StreamingStatus::OK; } StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) { - /// TODO: Fix msg_id_start 
and msg_id_end - Status status = PushQueueItem(channel_info_.current_seq_id + 1, data, data_size, - current_time_ms(), 0, 0); + StreamingMessageBundleMetaPtr meta = StreamingMessageBundleMeta::FromBytes(data); + uint64_t msg_id_end = meta->GetLastMessageId(); + uint64_t msg_id_start = + (meta->GetMessageListSize() == 0 ? msg_id_end + : msg_id_end - meta->GetMessageListSize() + 1); + STREAMING_LOG(DEBUG) << "ProduceItemToChannel, qid=" << channel_info_.channel_id + << ", msg_id_start=" << msg_id_start + << ", msg_id_end=" << msg_id_end << ", meta=" << *meta; + + Status status = + PushQueueItem(data, data_size, current_time_ms(), msg_id_start, msg_id_end); if (status.code() != StatusCode::OK) { STREAMING_LOG(DEBUG) << channel_info_.channel_id << " => Queue is full" << " meesage => " << status.message(); @@ -120,14 +105,14 @@ StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data, return StreamingStatus::OK; } -Status StreamingQueueProducer::PushQueueItem(uint64_t seq_id, uint8_t *data, - uint32_t data_size, uint64_t timestamp, - uint64_t msg_id_start, uint64_t msg_id_end) { +Status StreamingQueueProducer::PushQueueItem(uint8_t *data, uint32_t data_size, + uint64_t timestamp, uint64_t msg_id_start, + uint64_t msg_id_end) { STREAMING_LOG(DEBUG) << "StreamingQueueProducer::PushQueueItem:" - << " qid: " << channel_info_.channel_id << " seq_id: " << seq_id + << " qid: " << channel_info_.channel_id << " data_size: " << data_size; Status status = - queue_->Push(seq_id, data, data_size, timestamp, msg_id_start, msg_id_end, false); + queue_->Push(data, data_size, timestamp, msg_id_start, msg_id_end, false); if (status.IsOutOfMemory()) { status = queue_->TryEvictItems(); if (!status.ok()) { @@ -135,8 +120,7 @@ Status StreamingQueueProducer::PushQueueItem(uint64_t seq_id, uint8_t *data, return status; } - status = - queue_->Push(seq_id, data, data_size, timestamp, msg_id_start, msg_id_end, false); + status = queue_->Push(data, data_size, timestamp, 
msg_id_start, msg_id_end, false); } queue_->Send(); @@ -178,7 +162,7 @@ StreamingQueueStatus StreamingQueueConsumer::GetQueue( TransferCreationStatus StreamingQueueConsumer::CreateTransferChannel() { StreamingQueueStatus status = - GetQueue(channel_info_.channel_id, channel_info_.current_seq_id + 1, + GetQueue(channel_info_.channel_id, channel_info_.current_message_id + 1, channel_info_.parameter); if (status == StreamingQueueStatus::OK) { @@ -204,12 +188,11 @@ StreamingStatus StreamingQueueConsumer::ClearTransferCheckpoint( } StreamingStatus StreamingQueueConsumer::RefreshChannelInfo() { - channel_info_.queue_info.last_seq_id = queue_->GetLastRecvSeqId(); + channel_info_.queue_info.last_message_id = queue_->GetLastRecvMsgId(); return StreamingStatus::OK; } -StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint64_t &offset_id, - uint8_t *&data, +StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, uint32_t timeout) { STREAMING_LOG(INFO) << "GetQueueItem qid: " << channel_info_.channel_id; @@ -219,16 +202,14 @@ StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint64_t &offset_ STREAMING_LOG(INFO) << "GetQueueItem timeout."; data = nullptr; data_size = 0; - offset_id = QUEUE_INVALID_SEQ_ID; return StreamingStatus::OK; } data = item.Buffer()->Data(); - offset_id = item.SeqId(); data_size = item.Buffer()->Size(); STREAMING_LOG(DEBUG) << "GetQueueItem qid: " << channel_info_.channel_id - << " seq_id: " << offset_id << " msg_id: " << item.MaxMsgId() + << " seq_id: " << item.SeqId() << " msg_id: " << item.MaxMsgId() << " data_size: " << data_size; return StreamingStatus::OK; } @@ -249,7 +230,7 @@ struct MockQueueItem { class MockQueue { public: std::unordered_map>> - message_bffer; + message_buffer; std::unordered_map>> consumed_buffer; std::unordered_map queue_info_map; @@ -264,7 +245,7 @@ std::mutex MockQueue::mutex; StreamingStatus MockProducer::CreateTransferChannel() { std::unique_lock 
lock(MockQueue::mutex); MockQueue &mock_queue = MockQueue::GetMockQueue(); - mock_queue.message_bffer[channel_info_.channel_id] = + mock_queue.message_buffer[channel_info_.channel_id] = std::make_shared>(10000); mock_queue.consumed_buffer[channel_info_.channel_id] = std::make_shared>(10000); @@ -274,7 +255,7 @@ StreamingStatus MockProducer::CreateTransferChannel() { StreamingStatus MockProducer::DestroyTransferChannel() { std::unique_lock lock(MockQueue::mutex); MockQueue &mock_queue = MockQueue::GetMockQueue(); - mock_queue.message_bffer.erase(channel_info_.channel_id); + mock_queue.message_buffer.erase(channel_info_.channel_id); mock_queue.consumed_buffer.erase(channel_info_.channel_id); return StreamingStatus::OK; } @@ -282,44 +263,39 @@ StreamingStatus MockProducer::DestroyTransferChannel() { StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) { std::unique_lock lock(MockQueue::mutex); MockQueue &mock_queue = MockQueue::GetMockQueue(); - auto &ring_buffer = mock_queue.message_bffer[channel_info_.channel_id]; + auto &ring_buffer = mock_queue.message_buffer[channel_info_.channel_id]; if (ring_buffer->Full()) { return StreamingStatus::OutOfMemory; } MockQueueItem item; - item.seq_id = channel_info_.current_seq_id + 1; item.data.reset(new uint8_t[data_size]); item.data_size = data_size; std::memcpy(item.data.get(), data, data_size); ring_buffer->Push(item); - mock_queue.queue_info_map[channel_info_.channel_id].last_seq_id = item.seq_id; return StreamingStatus::OK; } StreamingStatus MockProducer::RefreshChannelInfo() { MockQueue &mock_queue = MockQueue::GetMockQueue(); - channel_info_.queue_info.consumed_seq_id = - mock_queue.queue_info_map[channel_info_.channel_id].consumed_seq_id; + channel_info_.queue_info.consumed_message_id = + mock_queue.queue_info_map[channel_info_.channel_id].consumed_message_id; return StreamingStatus::OK; } -StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, - 
uint32_t &data_size, +StreamingStatus MockConsumer::ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, uint32_t timeout) { std::unique_lock lock(MockQueue::mutex); MockQueue &mock_queue = MockQueue::GetMockQueue(); auto &channel_id = channel_info_.channel_id; - if (mock_queue.message_bffer.find(channel_id) == mock_queue.message_bffer.end()) { + if (mock_queue.message_buffer.find(channel_id) == mock_queue.message_buffer.end()) { return StreamingStatus::NoSuchItem; } - - if (mock_queue.message_bffer[channel_id]->Empty()) { + if (mock_queue.message_buffer[channel_id]->Empty()) { return StreamingStatus::NoSuchItem; } - MockQueueItem item = mock_queue.message_bffer[channel_id]->Front(); - mock_queue.message_bffer[channel_id]->Pop(); + MockQueueItem item = mock_queue.message_buffer[channel_id]->Front(); + mock_queue.message_buffer[channel_id]->Pop(); mock_queue.consumed_buffer[channel_id]->Push(item); - offset_id = item.seq_id; data = item.data.get(); data_size = item.data_size; return StreamingStatus::OK; @@ -333,14 +309,14 @@ StreamingStatus MockConsumer::NotifyChannelConsumed(uint64_t offset_id) { while (!ring_buffer->Empty() && ring_buffer->Front().seq_id <= offset_id) { ring_buffer->Pop(); } - mock_queue.queue_info_map[channel_id].consumed_seq_id = offset_id; + mock_queue.queue_info_map[channel_id].consumed_message_id = offset_id; return StreamingStatus::OK; } StreamingStatus MockConsumer::RefreshChannelInfo() { MockQueue &mock_queue = MockQueue::GetMockQueue(); - channel_info_.queue_info.last_seq_id = - mock_queue.queue_info_map[channel_info_.channel_id].last_seq_id; + channel_info_.queue_info.last_message_id = + mock_queue.queue_info_map[channel_info_.channel_id].last_message_id; return StreamingStatus::OK; } diff --git a/streaming/src/channel.h b/streaming/src/channel/channel.h similarity index 89% rename from streaming/src/channel.h rename to streaming/src/channel/channel.h index 4bb46aa15..733a07042 100644 --- a/streaming/src/channel.h +++ 
b/streaming/src/channel/channel.h @@ -1,9 +1,10 @@ #pragma once +#include "common/status.h" #include "config/streaming_config.h" #include "queue/queue_handler.h" -#include "ring_buffer.h" -#include "status.h" +#include "ring_buffer/ring_buffer.h" +#include "util/config.h" #include "util/streaming_util.h" namespace ray { @@ -19,9 +20,9 @@ enum class TransferCreationStatus : uint32_t { struct StreamingQueueInfo { uint64_t first_seq_id = 0; - uint64_t last_seq_id = 0; - uint64_t target_seq_id = 0; - uint64_t consumed_seq_id = 0; + uint64_t last_message_id = 0; + uint64_t target_message_id = 0; + uint64_t consumed_message_id = 0; }; struct ChannelCreationParameter { @@ -36,7 +37,6 @@ struct ProducerChannelInfo { ObjectID channel_id; StreamingRingBufferPtr writer_ring_buffer; uint64_t current_message_id; - uint64_t current_seq_id; uint64_t message_last_commit_id; StreamingQueueInfo queue_info; uint32_t queue_size; @@ -58,7 +58,6 @@ struct ProducerChannelInfo { struct ConsumerChannelInfo { ObjectID channel_id; uint64_t current_message_id; - uint64_t current_seq_id; uint64_t barrier_id; uint64_t partial_barrier_id; @@ -71,6 +70,7 @@ struct ConsumerChannelInfo { ChannelCreationParameter parameter; // Total count of notify request. 
uint64_t notify_cnt = 0; + uint64_t resend_notify_timer; }; /// Two types of channel are presented: @@ -111,8 +111,7 @@ class ConsumerChannel { virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, uint64_t checkpoint_offset) = 0; virtual StreamingStatus RefreshChannelInfo() = 0; - virtual StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, - uint32_t &data_size, + virtual StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, uint32_t timeout) = 0; virtual StreamingStatus NotifyChannelConsumed(uint64_t offset_id) = 0; @@ -136,8 +135,8 @@ class StreamingQueueProducer : public ProducerChannel { private: StreamingStatus CreateQueue(); - Status PushQueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, - uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end); + Status PushQueueItem(uint8_t *data, uint32_t data_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end); private: std::shared_ptr queue_; @@ -153,8 +152,8 @@ class StreamingQueueConsumer : public ConsumerChannel { StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, uint64_t checkpoint_offset) override; StreamingStatus RefreshChannelInfo() override; - StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, - uint32_t &data_size, uint32_t timeout) override; + StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, + uint32_t timeout) override; StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; private: @@ -204,8 +203,8 @@ class MockConsumer : public ConsumerChannel { return StreamingStatus::OK; } StreamingStatus RefreshChannelInfo() override; - StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, - uint32_t &data_size, uint32_t timeout) override; + StreamingStatus ConsumeItemFromChannel(uint8_t *&data, uint32_t &data_size, + uint32_t timeout) override; StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; }; diff 
--git a/streaming/src/status.h b/streaming/src/common/status.h similarity index 97% rename from streaming/src/status.h rename to streaming/src/common/status.h index dde6f386a..63a1cbaee 100644 --- a/streaming/src/status.h +++ b/streaming/src/common/status.h @@ -19,7 +19,6 @@ enum class StreamingStatus : uint32_t { GetBundleTimeOut = 9, SkipSendEmptyMessage = 10, Interrupted = 11, - WaitQueueTimeOut = 12, OutOfMemory = 13, Invalid = 14, UnknownError = 15, diff --git a/streaming/src/config/streaming_config.cc b/streaming/src/config/streaming_config.cc index f63b00d2e..7dc94c865 100644 --- a/streaming/src/config/streaming_config.cc +++ b/streaming/src/config/streaming_config.cc @@ -10,6 +10,7 @@ uint32_t StreamingConfig::DEFAULT_RING_BUFFER_CAPACITY = 500; uint32_t StreamingConfig::DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL = 20; // Time to force clean if barrier in queue, default 0ms const uint32_t StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE = 2048; +const uint32_t StreamingConfig::RESEND_NOTIFY_MAX_INTERVAL = 1000; // ms #define RESET_IF_INT_CONF(KEY, VALUE) \ if (0 != VALUE) { \ diff --git a/streaming/src/config/streaming_config.h b/streaming/src/config/streaming_config.h index 8f0cf2d5b..784f7b094 100644 --- a/streaming/src/config/streaming_config.h +++ b/streaming/src/config/streaming_config.h @@ -9,12 +9,20 @@ namespace ray { namespace streaming { +using ReliabilityLevel = proto::ReliabilityLevel; +using StreamingRole = proto::NodeType; + +#define DECL_GET_SET_PROPERTY(TYPE, NAME, VALUE) \ + TYPE Get##NAME() const { return VALUE; } \ + void Set##NAME(TYPE value) { VALUE = value; } + class StreamingConfig { public: static uint64_t TIME_WAIT_UINT; static uint32_t DEFAULT_RING_BUFFER_CAPACITY; static uint32_t DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL; static const uint32_t MESSAGE_BUNDLE_MAX_SIZE; + static const uint32_t RESEND_NOTIFY_MAX_INTERVAL; private: uint32_t ring_buffer_capacity_ = DEFAULT_RING_BUFFER_CAPACITY; @@ -40,12 +48,18 @@ class StreamingConfig { uint32_t 
event_driven_flow_control_interval_ = 1; + ReliabilityLevel streaming_strategy_ = ReliabilityLevel::EXACTLY_ONCE; + StreamingRole streaming_role = StreamingRole::TRANSFORM; + public: void FromProto(const uint8_t *, uint32_t size); -#define DECL_GET_SET_PROPERTY(TYPE, NAME, VALUE) \ - TYPE Get##NAME() const { return VALUE; } \ - void Set##NAME(TYPE value) { VALUE = value; } + inline bool IsAtLeastOnce() const { + return ReliabilityLevel::AT_LEAST_ONCE == streaming_strategy_; + } + inline bool IsExactlyOnce() const { + return ReliabilityLevel::EXACTLY_ONCE == streaming_strategy_; + } DECL_GET_SET_PROPERTY(const std::string &, WorkerName, worker_name_) DECL_GET_SET_PROPERTY(const std::string &, OpName, op_name_) @@ -58,6 +72,8 @@ class StreamingConfig { flow_control_type_) DECL_GET_SET_PROPERTY(uint32_t, EventDrivenFlowControlInterval, event_driven_flow_control_interval_) + DECL_GET_SET_PROPERTY(StreamingRole, StreamingRole, streaming_role) + DECL_GET_SET_PROPERTY(ReliabilityLevel, ReliabilityLevel, streaming_strategy_) uint32_t GetRingBufferCapacity() const; /// Note(lingxuan.zlx), RingBufferCapacity's valid range is from 1 to diff --git a/streaming/src/data_reader.cc b/streaming/src/data_reader.cc index ca59cf987..cf7238201 100644 --- a/streaming/src/data_reader.cc +++ b/streaming/src/data_reader.cc @@ -18,15 +18,16 @@ const uint32_t DataReader::kReadItemTimeout = 1000; void DataReader::Init(const std::vector &input_ids, const std::vector &init_params, - const std::vector &queue_seq_ids, const std::vector &streaming_msg_ids, + std::vector &creation_status, int64_t timer_interval) { Init(input_ids, init_params, timer_interval); for (size_t i = 0; i < input_ids.size(); ++i) { auto &q_id = input_ids[i]; - channel_info_map_[q_id].current_seq_id = queue_seq_ids[i]; + last_message_id_[q_id] = streaming_msg_ids[i]; channel_info_map_[q_id].current_message_id = streaming_msg_ids[i]; } + InitChannel(creation_status); } void DataReader::Init(const std::vector &input_ids, @@ 
-53,19 +54,23 @@ void DataReader::Init(const std::vector &input_ids, channel_info.last_queue_item_latency = 0; channel_info.last_queue_target_diff = 0; channel_info.get_queue_item_times = 0; + channel_info.resend_notify_timer = 0; } + reliability_helper_ = ReliabilityHelperFactory::CreateReliabilityHelper( + runtime_context_->GetConfig(), barrier_helper_, nullptr, this); + /// Make the input id location stable. sort(input_queue_ids_.begin(), input_queue_ids_.end(), [](const ObjectID &a, const ObjectID &b) { return a.Hash() < b.Hash(); }); std::copy(input_ids.begin(), input_ids.end(), std::back_inserter(unready_queue_ids_)); - InitChannel(); } -StreamingStatus DataReader::InitChannel() { +StreamingStatus DataReader::InitChannel( + std::vector &creation_status) { STREAMING_LOG(INFO) << "[Reader] Getting queues. total queue num " - << input_queue_ids_.size() << ", unready queue num => " - << unready_queue_ids_.size(); + << input_queue_ids_.size() + << ", unready queue num=" << unready_queue_ids_.size(); for (const auto &input_channel : unready_queue_ids_) { auto &channel_info = channel_info_map_[input_channel]; @@ -78,8 +83,10 @@ StreamingStatus DataReader::InitChannel() { channel_map_.emplace(input_channel, channel); TransferCreationStatus status = channel->CreateTransferChannel(); + creation_status.push_back(status); if (TransferCreationStatus::PullOk != status) { - STREAMING_LOG(ERROR) << "Initialize queue failed, id => " << input_channel; + STREAMING_LOG(ERROR) << "Initialize queue failed, id=" << input_channel + << ", status=" << static_cast(status); } } runtime_context_->SetRuntimeStatus(RuntimeStatus::Running); @@ -87,10 +94,11 @@ StreamingStatus DataReader::InitChannel() { return StreamingStatus::OK; } -StreamingStatus DataReader::InitChannelMerger() { +StreamingStatus DataReader::InitChannelMerger(uint32_t timeout_ms) { STREAMING_LOG(INFO) << "[Reader] Initializing queue merger."; // Init reader merger by given comparator when it's first created. 
- StreamingReaderMsgPtrComparator comparator; + StreamingReaderMsgPtrComparator comparator( + runtime_context_->GetConfig().GetReliabilityLevel()); if (!reader_merger_) { reader_merger_.reset( new PriorityQueue, StreamingReaderMsgPtrComparator>( @@ -100,106 +108,255 @@ StreamingStatus DataReader::InitChannelMerger() { // An old item in merger vector must be evicted before new queue item has been // pushed. if (!unready_queue_ids_.empty() && last_fetched_queue_item_) { - STREAMING_LOG(INFO) << "pop old item from => " << last_fetched_queue_item_->from; - RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_)) + STREAMING_LOG(INFO) << "pop old item from=" << last_fetched_queue_item_->from; + RETURN_IF_NOT_OK(StashNextMessageAndPop(last_fetched_queue_item_, timeout_ms)) last_fetched_queue_item_.reset(); } // Create initial heap for priority queue. + std::vector unready_queue_ids_stashed; for (auto &input_queue : unready_queue_ids_) { std::shared_ptr msg = std::make_shared(); - RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info_map_[input_queue], msg)) - channel_info_map_[msg->from].current_seq_id = msg->seq_id; + auto status = GetMessageFromChannel(channel_info_map_[input_queue], msg, timeout_ms, + timeout_ms); + if (StreamingStatus::OK != status) { + STREAMING_LOG(INFO) + << "[Reader] initializing merger, get message from channel timeout, " + << input_queue << ", status => " << static_cast(status); + unready_queue_ids_stashed.push_back(input_queue); + continue; + } channel_info_map_[msg->from].current_message_id = msg->meta->GetLastMessageId(); reader_merger_->push(msg); } - STREAMING_LOG(INFO) << "[Reader] Initializing merger done."; - return StreamingStatus::OK; + if (unready_queue_ids_stashed.empty()) { + STREAMING_LOG(INFO) << "[Reader] Initializing merger done."; + return StreamingStatus::OK; + } else { + STREAMING_LOG(INFO) << "[Reader] Initializing merger unfinished."; + unready_queue_ids_ = unready_queue_ids_stashed; + return 
StreamingStatus::GetBundleTimeOut; + } } StreamingStatus DataReader::GetMessageFromChannel(ConsumerChannelInfo &channel_info, - std::shared_ptr &message) { + std::shared_ptr &message, + uint32_t timeout_ms, + uint32_t wait_time_ms) { auto &qid = channel_info.channel_id; + message->from = qid; last_read_q_id_ = qid; - STREAMING_LOG(DEBUG) << "[Reader] send get request queue seq id => " << qid; - while (RuntimeStatus::Running == runtime_context_->GetRuntimeStatus() && - !message->data) { - auto status = channel_map_[channel_info.channel_id]->ConsumeItemFromChannel( - message->seq_id, message->data, message->data_size, kReadItemTimeout); + + bool is_valid_bundle = false; + int64_t start_time = current_sys_time_ms(); + STREAMING_LOG(DEBUG) << "GetMessageFromChannel, timeout_ms=" << timeout_ms + << ", wait_time_ms=" << wait_time_ms; + while (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Running && + !is_valid_bundle && current_sys_time_ms() - start_time < timeout_ms) { + STREAMING_LOG(DEBUG) << "[Reader] send get request queue seq id=" << qid; + /// In AT_LEAST_ONCE, wait_time_ms is set to 0, means `ConsumeItemFromChannel` + /// will return immediately if no items in queue. At the same time, `timeout_ms` is + /// ignored. + channel_map_[channel_info.channel_id]->ConsumeItemFromChannel( + message->data, message->data_size, wait_time_ms); + channel_info.get_queue_item_times++; if (!message->data) { - STREAMING_LOG(DEBUG) << "[Reader] Queue " << qid << " status " << status - << " get item timeout, resend notify " - << channel_info.current_seq_id; - // TODO(lingxuan.zlx): notify consumed when it's timeout. + RETURN_IF_NOT_OK(reliability_helper_->HandleNoValidItem(channel_info)); + } else { + uint64_t current_time = current_sys_time_ms(); + channel_info.resend_notify_timer = current_time; + // Note(lingxuan.zlx): To find which channel get an invalid data and + // print channel id for debugging. 
+ STREAMING_CHECK(StreamingMessageBundleMeta::CheckBundleMagicNum(message->data)) + << "Magic number invalid, from channel " << channel_info.channel_id; + message->meta = StreamingMessageBundleMeta::FromBytes(message->data); + + is_valid_bundle = true; + if (!runtime_context_->GetConfig().IsAtLeastOnce()) { + // filter message when msg_id doesn't match. + // reader will filter message only when using streaming queue and + // non-at-least-once mode + BundleCheckStatus status = CheckBundle(message); + STREAMING_LOG(DEBUG) << "CheckBundle, result=" << status + << ", last_msg_id=" << last_message_id_[message->from]; + if (status == BundleCheckStatus::BundleToBeSplit) { + SplitBundle(message, last_message_id_[qid]); + } + if (status == BundleCheckStatus::BundleToBeThrown && message->meta->IsBarrier()) { + STREAMING_LOG(WARNING) + << "Throw barrier, msg_id=" << message->meta->GetLastMessageId(); + } + is_valid_bundle = status != BundleCheckStatus::BundleToBeThrown; + } } } if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) { return StreamingStatus::Interrupted; } - STREAMING_LOG(DEBUG) << "[Reader] recevied queue seq id => " << message->seq_id - << ", queue id => " << qid; - message->from = qid; - message->meta = StreamingMessageBundleMeta::FromBytes(message->data); + if (!is_valid_bundle) { + STREAMING_LOG(DEBUG) << "GetMessageFromChannel timeout, qid=" + << channel_info.channel_id; + return StreamingStatus::GetBundleTimeOut; + } + + STREAMING_LOG(DEBUG) << "[Reader] received message id=" + << message->meta->GetLastMessageId() << ", queue id=" << qid; + last_message_id_[message->from] = message->meta->GetLastMessageId(); return StreamingStatus::OK; } -StreamingStatus DataReader::StashNextMessage(std::shared_ptr &message) { - // Push new message into priority queue and record the channel metrics in - // channel info. 
+BundleCheckStatus DataReader::CheckBundle(const std::shared_ptr &message) { + uint64_t end_msg_id = message->meta->GetLastMessageId(); + uint64_t start_msg_id = message->meta->IsEmptyMsg() + ? end_msg_id + : end_msg_id - message->meta->GetMessageListSize() + 1; + uint64_t last_msg_id = last_message_id_[message->from]; + + // Writer will keep sending bundles when downstream reader failover. After reader + // recovered, it will receive these bundles whoes msg_id is larger than expected. + if (start_msg_id > last_msg_id + 1) { + return BundleCheckStatus::BundleToBeThrown; + } + if (end_msg_id < last_msg_id + 1) { + // Empty message and barrier's msg_id equals to last message, so we shouldn't throw + // them. + return end_msg_id == last_msg_id && !message->meta->IsBundle() + ? BundleCheckStatus::OkBundle + : BundleCheckStatus::BundleToBeThrown; + } + // Normal bundles. + if (start_msg_id == last_msg_id + 1) { + return BundleCheckStatus::OkBundle; + } + return BundleCheckStatus::BundleToBeSplit; +} + +void DataReader::SplitBundle(std::shared_ptr &message, uint64_t last_msg_id) { + std::list msg_list; + StreamingMessageBundle::GetMessageListFromRawData( + message->data + kMessageBundleHeaderSize, + message->data_size - kMessageBundleHeaderSize, message->meta->GetMessageListSize(), + msg_list); + uint32_t bundle_size = 0; + for (auto it = msg_list.begin(); it != msg_list.end();) { + if ((*it)->GetMessageId() > last_msg_id) { + bundle_size += (*it)->ClassBytesSize(); + it++; + } else { + it = msg_list.erase(it); + } + } + STREAMING_LOG(DEBUG) << "Split message, from_queue_id=" << message->from + << ", start_msg_id=" << msg_list.front()->GetMessageId() + << ", end_msg_id=" << msg_list.back()->GetMessageId(); + // recreate bundle + auto cut_msg_bundle = std::make_shared( + msg_list, message->meta->GetMessageBundleTs(), msg_list.back()->GetMessageId(), + StreamingMessageBundleType::Bundle, bundle_size); + message->Realloc(cut_msg_bundle->ClassBytesSize()); + 
cut_msg_bundle->ToBytes(message->data); + message->meta = StreamingMessageBundleMeta::FromBytes(message->data); +} + +StreamingStatus DataReader::StashNextMessageAndPop(std::shared_ptr &message, + uint32_t timeout_ms) { + STREAMING_LOG(DEBUG) << "StashNextMessageAndPop, timeout_ms=" << timeout_ms; + + // Get the first message. + message = reader_merger_->top(); + STREAMING_LOG(DEBUG) << "Messages to be poped=" << *message + << ", merger size=" << reader_merger_->size(); + + // Then stash next message from its from queue. std::shared_ptr new_msg = std::make_shared(); auto &channel_info = channel_info_map_[message->from]; - reader_merger_->pop(); - int64_t cur_time = current_time_ms(); - RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info, new_msg)) + RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info, new_msg, timeout_ms, timeout_ms)) + new_msg->last_barrier_id = channel_info.barrier_id; reader_merger_->push(new_msg); + STREAMING_LOG(DEBUG) << "New message pushed=" << *new_msg + << ", merger size=" << reader_merger_->size(); + + // Pop message. + reader_merger_->pop(); + STREAMING_LOG(DEBUG) << "Message poped, msg=" << *message; + + // Record some metrics. 
channel_info.last_queue_item_delay = new_msg->meta->GetMessageBundleTs() - message->meta->GetMessageBundleTs(); - channel_info.last_queue_item_latency = current_time_ms() - cur_time; + channel_info.last_queue_item_latency = current_time_ms() - current_time_ms(); return StreamingStatus::OK; } StreamingStatus DataReader::GetMergedMessageBundle(std::shared_ptr &message, - bool &is_valid_break) { - int64_t cur_time = current_time_ms(); - if (last_fetched_queue_item_) { - RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_)) - } - message = reader_merger_->top(); - last_fetched_queue_item_ = message; + bool &is_valid_break, + uint32_t timeout_ms) { + RETURN_IF_NOT_OK(StashNextMessageAndPop(message, timeout_ms)) + auto &offset_info = channel_info_map_[message->from]; - uint64_t cur_queue_previous_msg_id = offset_info.current_message_id; - STREAMING_LOG(DEBUG) << "[Reader] [Bundle] from q_id =>" << message->from << "cur => " - << cur_queue_previous_msg_id << ", message list size" - << message->meta->GetMessageListSize() << ", lst message id =>" - << message->meta->GetLastMessageId() << ", q seq id => " - << message->seq_id << ", last barrier id => " << message->data_size - << ", " << message->meta->GetMessageBundleTs(); - + STREAMING_LOG(DEBUG) << "[Reader] [Bundle]" << *message + << ", cur_queue_previous_msg_id=" << cur_queue_previous_msg_id; + int64_t cur_time = current_time_ms(); if (message->meta->IsBundle()) { last_message_ts_ = cur_time; is_valid_break = true; - } else if (timer_interval_ != -1 && cur_time - last_message_ts_ > timer_interval_) { - // Throw empty message when reaching timer_interval. 
+ } else if (message->meta->IsBarrier() && BarrierAlign(message)) { + last_message_ts_ = cur_time; + is_valid_break = true; + } else if (timer_interval_ != -1 && cur_time - last_message_ts_ >= timer_interval_ && + message->meta->IsEmptyMsg()) { + // Sent empty message when reaching timer_interval last_message_ts_ = cur_time; is_valid_break = true; } offset_info.current_message_id = message->meta->GetLastMessageId(); - offset_info.current_seq_id = message->seq_id; last_bundle_ts_ = message->meta->GetMessageBundleTs(); - STREAMING_LOG(DEBUG) << "[Reader] [Bundle] message type =>" - << static_cast(message->meta->GetBundleType()) - << " from id => " << message->from << ", queue seq id =>" - << message->seq_id << ", message id => " - << message->meta->GetLastMessageId(); + STREAMING_LOG(DEBUG) << "[Reader] [Bundle] Get merged message bundle=" << *message + << ", is_valid_break=" << is_valid_break; + last_fetched_queue_item_ = message; return StreamingStatus::OK; } +bool DataReader::BarrierAlign(std::shared_ptr &message) { + // Arrange barrier action when barrier is arriving. + StreamingBarrierHeader barrier_header; + StreamingMessage::GetBarrierIdFromRawData(message->data + kMessageHeaderSize, + &barrier_header); + uint64_t barrier_id = barrier_header.barrier_id; + auto *barrier_align_cnt = &global_barrier_cnt_; + auto &channel_info = channel_info_map_[message->from]; + // Target count is input vector size (global barrier). + uint32_t target_count = 0; + + channel_info.barrier_id = barrier_header.barrier_id; + target_count = input_queue_ids_.size(); + (*barrier_align_cnt)[barrier_id]++; + // The next message checkpoint is changed if this's barrier message. 
+ STREAMING_LOG(INFO) << "[Reader] [Barrier] get barrier, barrier_id=" << barrier_id + << ", barrier_cnt=" << (*barrier_align_cnt)[barrier_id] + << ", global barrier id=" << barrier_header.barrier_id + << ", from q_id=" << message->from << ", barrier type=" + << static_cast(barrier_header.barrier_type) + << ", target count=" << target_count; + // Notify invoker the last barrier, so that checkpoint or something related can be + // taken right now. + if ((*barrier_align_cnt)[barrier_id] == target_count) { + // map can't be used in multithread (crash in report timer) + barrier_align_cnt->erase(barrier_id); + STREAMING_LOG(INFO) + << "[Reader] [Barrier] last barrier received, return barrier. barrier_id = " + << barrier_id << ", from q_id=" << message->from; + return true; + } + return false; +} + StreamingStatus DataReader::GetBundle(const uint32_t timeout_ms, std::shared_ptr &message) { + STREAMING_LOG(DEBUG) << "GetBundle, timeout_ms=" << timeout_ms; // Notify upstream that last fetched item has been consumed. 
if (last_fetched_queue_item_) { NotifyConsumed(last_fetched_queue_item_); @@ -222,28 +379,25 @@ StreamingStatus DataReader::GetBundle(const uint32_t timeout_ms, return StreamingStatus::GetBundleTimeOut; } if (!unready_queue_ids_.empty()) { - StreamingStatus status = InitChannel(); + std::vector creation_status; + StreamingStatus status = InitChannel(creation_status); switch (status) { case StreamingStatus::InitQueueFailed: break; - case StreamingStatus::WaitQueueTimeOut: - STREAMING_LOG(ERROR) - << "Wait upstream queue timeout, maybe some actors in deadlock"; - break; default: STREAMING_LOG(INFO) << "Init reader queue in GetBundle"; } if (StreamingStatus::OK != status) { return status; } - RETURN_IF_NOT_OK(InitChannelMerger()) + RETURN_IF_NOT_OK(InitChannelMerger(timeout_ms)) unready_queue_ids_.clear(); auto &merge_vec = reader_merger_->getRawVector(); for (auto &bundle : merge_vec) { - STREAMING_LOG(INFO) << "merger vector item => " << bundle->from; + STREAMING_LOG(INFO) << "merger vector item=" << bundle->from; } } - RETURN_IF_NOT_OK(GetMergedMessageBundle(message, is_valid_break)); + RETURN_IF_NOT_OK(GetMergedMessageBundle(message, is_valid_break, timeout_ms)); if (!is_valid_break) { empty_bundle_cnt++; NotifyConsumed(message); @@ -261,16 +415,12 @@ void DataReader::GetOffsetInfo( offset_map = &channel_info_map_; for (auto &offset_info : channel_info_map_) { STREAMING_LOG(INFO) << "[Reader] [GetOffsetInfo], q id " << offset_info.first - << ", seq id => " << offset_info.second.current_seq_id - << ", message id => " << offset_info.second.current_message_id; + << ", message id=" << offset_info.second.current_message_id; } } void DataReader::NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset) { channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset); - if (offset == channel_info.queue_info.last_seq_id) { - STREAMING_LOG(DEBUG) << "notify seq id equal to last seq id => " << offset; - } } DataReader::DataReader(std::shared_ptr 
&runtime_context) @@ -286,37 +436,42 @@ void DataReader::NotifyConsumed(std::shared_ptr &message) { auto &channel_info = channel_info_map_[message->from]; auto &queue_info = channel_info.queue_info; channel_info.notify_cnt++; - if (queue_info.target_seq_id <= message->seq_id) { - NotifyConsumedItem(channel_info, message->seq_id); + if (queue_info.target_message_id <= message->meta->GetLastMessageId()) { + NotifyConsumedItem(channel_info, message->meta->GetLastMessageId()); channel_map_[channel_info.channel_id]->RefreshChannelInfo(); - if (queue_info.last_seq_id != QUEUE_INVALID_SEQ_ID) { - uint64_t original_target_seq_id = queue_info.target_seq_id; - queue_info.target_seq_id = std::min( - queue_info.last_seq_id, - message->seq_id + runtime_context_->GetConfig().GetReaderConsumedStep()); + if (queue_info.last_message_id != QUEUE_INVALID_SEQ_ID) { + uint64_t original_target_message_id = queue_info.target_message_id; + queue_info.target_message_id = + std::min(queue_info.last_message_id, + message->meta->GetLastMessageId() + + runtime_context_->GetConfig().GetReaderConsumedStep()); channel_info.last_queue_target_diff = - queue_info.target_seq_id - original_target_seq_id; + queue_info.target_message_id - original_target_message_id; } else { STREAMING_LOG(WARNING) << "[Reader] [QueueInfo] channel id " << message->from - << ", last seq id " << queue_info.last_seq_id; + << ", last message id " << queue_info.last_message_id; } STREAMING_LOG(DEBUG) << "[Reader] [Consumed] Trigger notify consumed" - << ", channel id => " << message->from << ", last seq id => " - << queue_info.last_seq_id << ", target seq id => " - << queue_info.target_seq_id << ", consumed seq id => " - << message->seq_id << ", last message id => " - << message->meta->GetLastMessageId() << ", bundle type => " + << ", channel id=" << message->from + << ", last message id=" << queue_info.last_message_id + << ", target message id=" << queue_info.target_message_id + << ", consumed message id=" << 
message->meta->GetLastMessageId() + << ", bundle type=" << static_cast(message->meta->GetBundleType()) - << ", last message bundle ts => " + << ", last message bundle ts=" << message->meta->GetMessageBundleTs(); } } bool StreamingReaderMsgPtrComparator::operator()(const std::shared_ptr &a, const std::shared_ptr &b) { + if (comp_strategy == ReliabilityLevel::EXACTLY_ONCE) { + if (a->last_barrier_id != b->last_barrier_id) + return a->last_barrier_id > b->last_barrier_id; + } STREAMING_CHECK(a->meta); - // We use hash value of id for stability of message in sorting. + // We proposed fixed id sequnce for stability of message in sorting. if (a->meta->GetMessageBundleTs() == b->meta->GetMessageBundleTs()) { return a->from.Hash() > b->from.Hash(); } diff --git a/streaming/src/data_reader.h b/streaming/src/data_reader.h index ac0d1c469..615047ecc 100644 --- a/streaming/src/data_reader.h +++ b/streaming/src/data_reader.h @@ -7,27 +7,38 @@ #include #include -#include "channel.h" +#include "channel/channel.h" #include "message/message_bundle.h" #include "message/priority_queue.h" +#include "reliability/barrier_helper.h" +#include "reliability_helper.h" #include "runtime_context.h" namespace ray { namespace streaming { -/// Databundle is super-bundle that contains channel information (upstream -/// channel id & bundle meta data) and raw buffer pointer. -struct DataBundle { - uint8_t *data = nullptr; - uint32_t data_size; - ObjectID from; - uint64_t seq_id; - StreamingMessageBundleMetaPtr meta; +class ReliabilityHelper; +class AtLeastOnceHelper; + +enum class BundleCheckStatus : uint32_t { + OkBundle = 0, + BundleToBeThrown = 1, + BundleToBeSplit = 2 }; +static inline std::ostream &operator<<(std::ostream &os, + const BundleCheckStatus &status) { + os << static_cast::type>(status); + return os; +} + /// This is implementation of merger policy in StreamingReaderMsgPtrComparator. 
struct StreamingReaderMsgPtrComparator { - StreamingReaderMsgPtrComparator() = default; + explicit StreamingReaderMsgPtrComparator(ReliabilityLevel strategy) + : comp_strategy(strategy){}; + StreamingReaderMsgPtrComparator(){}; + ReliabilityLevel comp_strategy = ReliabilityLevel::EXACTLY_ONCE; + bool operator()(const std::shared_ptr &a, const std::shared_ptr &b); }; @@ -50,6 +61,8 @@ class DataReader { std::shared_ptr last_fetched_queue_item_; + std::unordered_map global_barrier_cnt_; + int64_t timer_interval_; int64_t last_bundle_ts_; int64_t last_message_ts_; @@ -59,6 +72,12 @@ class DataReader { ObjectID last_read_q_id_; static const uint32_t kReadItemTimeout; + StreamingBarrierHelper barrier_helper_; + std::shared_ptr reliability_helper_; + std::unordered_map last_message_id_; + + friend class ReliabilityHelper; + friend class AtLeastOnceHelper; protected: std::unordered_map channel_info_map_; @@ -73,15 +92,20 @@ class DataReader { /// During initialization, only the channel parameters and necessary member properties /// are assigned. All channels will be connected in the first reading operation. /// \param input_ids - /// \param actor_ids - /// \param channel_seq_ids + /// \param init_params /// \param msg_ids + /// \param[out] creation_status /// \param timer_interval void Init(const std::vector &input_ids, const std::vector &init_params, - const std::vector &channel_seq_ids, - const std::vector &msg_ids, int64_t timer_interval); + const std::vector &msg_ids, + std::vector &creation_status, int64_t timer_interval); + /// Create reader use msg_id=0, this method is public only for test, and users + /// usuallly don't need it. + /// \param input_ids + /// \param init_params + /// \param timer_interval void Init(const std::vector &input_ids, const std::vector &init_params, int64_t timer_interval); @@ -108,22 +132,30 @@ class DataReader { private: /// Create channels and connect to all upstream. 
- StreamingStatus InitChannel(); + StreamingStatus InitChannel(std::vector &creation_status); /// One item from every channel will be popped out, then collecting /// them to a merged queue. High prioprity items will be fetched one by one. /// When item pop from one channel where must produce new item for placeholder /// in merged queue. - StreamingStatus InitChannelMerger(); + StreamingStatus InitChannelMerger(uint32_t timeout_ms); - StreamingStatus StashNextMessage(std::shared_ptr &message); + StreamingStatus StashNextMessageAndPop(std::shared_ptr &message, + uint32_t timeout_ms); StreamingStatus GetMessageFromChannel(ConsumerChannelInfo &channel_info, - std::shared_ptr &message); + std::shared_ptr &message, + uint32_t timeout_ms, uint32_t wait_time_ms); /// Get top item from prioprity queue. StreamingStatus GetMergedMessageBundle(std::shared_ptr &message, - bool &is_valid_break); + bool &is_valid_break, uint32_t timeout_ms); + + bool BarrierAlign(std::shared_ptr &message); + + BundleCheckStatus CheckBundle(const std::shared_ptr &message); + + static void SplitBundle(std::shared_ptr &message, uint64_t last_msg_id); }; } // namespace streaming } // namespace ray diff --git a/streaming/src/data_writer.cc b/streaming/src/data_writer.cc index 733ee2a34..e28ba26eb 100644 --- a/streaming/src/data_writer.cc +++ b/streaming/src/data_writer.cc @@ -63,7 +63,9 @@ uint64_t DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *dat uint32_t data_size, StreamingMessageType message_type) { STREAMING_LOG(DEBUG) << "WriteMessageToBufferRing q_id: " << q_id - << " data_size: " << data_size; + << " data_size: " << data_size + << ", message_type=" << static_cast(message_type) + << ", data=" << Util::Byte2hex(data, data_size); // TODO(lingxuan.zlx): currently, unsafe in multithreads ProducerChannelInfo &channel_info = channel_info_map_[q_id]; // Write message id stands for current lastest message id and differs from @@ -89,7 +91,7 @@ uint64_t 
DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *dat STREAMING_LOG(DEBUG) << "user_event had been in event_queue"; } else if (!channel_info.flow_control) { channel_info.in_event_queue = true; - Event event{&channel_info, EventType::UserEvent, false}; + Event event(&channel_info, EventType::UserEvent, false); event_service_->Push(event); ++channel_info.user_event_cnt; } @@ -152,6 +154,9 @@ StreamingStatus DataWriter::Init(const std::vector &queue_id_vec, flow_controller_ = std::make_shared(); break; } + + reliability_helper_ = ReliabilityHelperFactory::CreateReliabilityHelper( + runtime_context_->GetConfig(), barrier_helper_, this, nullptr); // Register empty event and user event to event server. event_service_ = std::make_shared(); event_service_->Register( @@ -166,6 +171,48 @@ StreamingStatus DataWriter::Init(const std::vector &queue_id_vec, return StreamingStatus::OK; } +void DataWriter::BroadcastBarrier(uint64_t barrier_id, const uint8_t *data, + uint32_t data_size) { + STREAMING_LOG(INFO) << "broadcast checkpoint id : " << barrier_id; + barrier_helper_.MapBarrierToCheckpoint(barrier_id, barrier_id); + + if (barrier_helper_.Contains(barrier_id)) { + STREAMING_LOG(WARNING) << "replicated global barrier id => " << barrier_id; + return; + } + + std::vector barrier_id_vec; + barrier_helper_.GetAllBarrier(barrier_id_vec); + if (barrier_id_vec.size() > 0) { + // Show all stashed barrier ids that means these checkpoint are not finished + // yet. 
+ STREAMING_LOG(WARNING) << "[Writer] [Barrier] previous barrier(checkpoint) was fail " + "to do some opearting, ids => " + << Util::join(barrier_id_vec.begin(), barrier_id_vec.end(), + "|"); + } + StreamingBarrierHeader barrier_header(StreamingBarrierType::GlobalBarrier, barrier_id); + + auto barrier_payload = + StreamingMessage::MakeBarrierPayload(barrier_header, data, data_size); + auto payload_size = kBarrierHeaderSize + data_size; + for (auto &queue_id : output_queue_ids_) { + uint64_t barrier_message_id = WriteMessageToBufferRing( + queue_id, barrier_payload.get(), payload_size, StreamingMessageType::Barrier); + if (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Interrupted) { + STREAMING_LOG(WARNING) << " stop right now"; + return; + } + + STREAMING_LOG(INFO) << "[Writer] [Barrier] write barrier to => " << queue_id + << ", barrier message id =>" << barrier_message_id + << ", barrier id => " << barrier_id; + } + + STREAMING_LOG(INFO) << "[Writer] [Barrier] global barrier id in runtime => " + << barrier_id; +} + DataWriter::DataWriter(std::shared_ptr &runtime_context) : transfer_config_(new Config()), runtime_context_(runtime_context) {} @@ -237,7 +284,6 @@ StreamingStatus DataWriter::WriteEmptyMessage(ProducerChannelInfo &channel_info) q_ringbuffer->FreeTransientBuffer(); RETURN_IF_NOT_OK(status) - channel_info.current_seq_id++; channel_info.message_pass_by_ts = current_time_ms(); return StreamingStatus::OK; } @@ -248,7 +294,6 @@ StreamingStatus DataWriter::WriteTransientBufferToChannel( StreamingStatus status = channel_map_[channel_info.channel_id]->ProduceItemToChannel( buffer_ptr->GetTransientBufferMutable(), buffer_ptr->GetTransientBufferSize()); RETURN_IF_NOT_OK(status) - channel_info.current_seq_id++; auto transient_bundle_meta = StreamingMessageBundleMeta::FromBytes(buffer_ptr->GetTransientBuffer()); bool is_barrier_bundle = transient_bundle_meta->IsBarrier(); @@ -267,9 +312,23 @@ bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo 
&channel_info, std::list message_list; uint32_t bundle_buffer_size = 0; const uint32_t max_queue_item_size = channel_info.queue_size; + + bool is_barrier = false; + + // Pop until one of the following condition meets: + // 1. ring buffer is empty + // 2. message count in bundle is larger than ring buffer size + // 3. sum of data size of messages in bundle is larger than streaming queue size + // 4. message type changed while (message_list.size() < runtime_context_->GetConfig().GetRingBufferCapacity() && !buffer_ptr->IsEmpty()) { StreamingMessagePtr &message_ptr = buffer_ptr->Front(); + STREAMING_LOG(DEBUG) << "Collecting message " << *message_ptr + << ", message_list_size=" << message_list.size() + << ", buffer capacity=" + << runtime_context_->GetConfig().GetRingBufferCapacity() + << ", buffer size=" << buffer_ptr->Size(); + uint32_t message_total_size = message_ptr->ClassBytesSize(); if (!message_list.empty() && bundle_buffer_size + message_total_size >= max_queue_item_size) { @@ -279,6 +338,11 @@ bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, } if (!message_list.empty() && message_list.back()->GetMessageType() != message_ptr->GetMessageType()) { + STREAMING_LOG(DEBUG) << "Different message type detected, break collecting, last " + "message type in list=" + << static_cast(message_list.back()->GetMessageType()) + << ", current collecing message type=" + << static_cast(message_ptr->GetMessageType()); break; } // ClassBytesSize = DataSize + MetaDataSize @@ -287,6 +351,12 @@ bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, message_list.push_back(message_ptr); buffer_ptr->Pop(); buffer_remain = buffer_ptr->Size(); + is_barrier = message_ptr->IsBarrier(); + STREAMING_LOG(DEBUG) << "Message " << *message_ptr + << " collected, message_list_size=" << message_list.size() + << ", buffer capacity=" + << runtime_context_->GetConfig().GetRingBufferCapacity() + << ", buffer size=" << buffer_ptr->Size(); } if 
(bundle_buffer_size >= channel_info.queue_size) { @@ -296,9 +366,16 @@ bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, } StreamingMessageBundlePtr bundle_ptr; + StreamingMessageBundleType bundleType = StreamingMessageBundleType::Bundle; + if (is_barrier) { + bundleType = StreamingMessageBundleType::Barrier; + } bundle_ptr = std::make_shared( - std::move(message_list), current_time_ms(), message_list.back()->GetMessageSeqId(), - StreamingMessageBundleType::Bundle, bundle_buffer_size); + std::move(message_list), current_time_ms(), message_list.back()->GetMessageId(), + bundleType, bundle_buffer_size); + + STREAMING_LOG(DEBUG) << "CollectFromRingBuffer done, bundle=" << *bundle_ptr; + buffer_ptr->ReallocTransientBuffer(bundle_ptr->ClassBytesSize()); bundle_ptr->ToBytes(buffer_ptr->GetTransientBufferMutable()); @@ -386,7 +463,7 @@ void DataWriter::EmptyMessageTimerCallback() { } if (current_ts - channel_info.message_pass_by_ts >= runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) { - Event event{&channel_info, EventType::EmptyEvent, true}; + Event event(&channel_info, EventType::EmptyEvent, true); event_service_->Push(event); ++channel_info.sent_empty_cnt; ++count; @@ -427,14 +504,14 @@ void DataWriter::RefreshChannelAndNotifyConsumed(ProducerChannelInfo &channel_in // Refresh current downstream consumed seq id. channel_map_[channel_info.channel_id]->RefreshChannelInfo(); // Notify the consumed information to local channel. 
- NotifyConsumedItem(channel_info, channel_info.queue_info.consumed_seq_id); + NotifyConsumedItem(channel_info, channel_info.queue_info.consumed_message_id); } void DataWriter::NotifyConsumedItem(ProducerChannelInfo &channel_info, uint32_t offset) { - if (offset > channel_info.current_seq_id) { + if (offset > channel_info.current_message_id) { STREAMING_LOG(WARNING) << "Can not notify consumed this offset " << offset << " that's out of range, max seq id " - << channel_info.current_seq_id; + << channel_info.current_message_id; } else { channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset); } @@ -472,5 +549,58 @@ void DataWriter::GetOffsetInfo( offset_map = &channel_info_map_; } +void DataWriter::ClearCheckpoint(uint64_t barrier_id) { + if (!barrier_helper_.Contains(barrier_id)) { + STREAMING_LOG(WARNING) << "no such barrier id => " << barrier_id; + return; + } + + std::string global_barrier_id_list_str = "|"; + + for (auto &queue_id : output_queue_ids_) { + uint64_t q_global_barrier_msg_id = 0; + StreamingStatus status = barrier_helper_.GetMsgIdByBarrierId(queue_id, barrier_id, + q_global_barrier_msg_id); + ProducerChannelInfo &channel_info = channel_info_map_[queue_id]; + if (status == StreamingStatus::OK) { + ClearCheckpointId(channel_info, q_global_barrier_msg_id); + } else { + STREAMING_LOG(WARNING) << "no seq record in q => " << queue_id << ", barrier id => " + << barrier_id; + } + global_barrier_id_list_str += + queue_id.Hex() + " : " + std::to_string(q_global_barrier_msg_id) + "| "; + reliability_helper_->CleanupCheckpoint(channel_info, barrier_id); + } + + STREAMING_LOG(INFO) + << "[Writer] [Barrier] [clear] global barrier flag, global barrier id => " + << barrier_id << ", seq id map => " << global_barrier_id_list_str; + + barrier_helper_.ReleaseBarrierMapById(barrier_id); + barrier_helper_.ReleaseBarrierMapCheckpointByBarrierId(barrier_id); +} + +void DataWriter::ClearCheckpointId(ProducerChannelInfo &channel_info, uint64_t msg_id) { + 
AutoSpinLock lock(notify_flag_); + + uint64_t current_msg_id = channel_info.current_message_id; + if (msg_id > current_msg_id) { + STREAMING_LOG(WARNING) << "current_msg_id=" << current_msg_id + << ", msg_id to be cleared=" << msg_id + << ", channel id = " << channel_info.channel_id; + } + channel_map_[channel_info.channel_id]->NotifyChannelConsumed(msg_id); + + STREAMING_LOG(DEBUG) << "clearing data from msg_id=" << msg_id + << ", qid= " << channel_info.channel_id; +} + +void DataWriter::GetChannelOffset(std::vector &result) { + for (auto &q_id : output_queue_ids_) { + result.push_back(channel_info_map_[q_id].current_message_id); + } +} + } // namespace streaming } // namespace ray diff --git a/streaming/src/data_writer.h b/streaming/src/data_writer.h index 43e3f7189..e2a18c334 100644 --- a/streaming/src/data_writer.h +++ b/streaming/src/data_writer.h @@ -6,15 +6,18 @@ #include #include -#include "channel.h" +#include "channel/channel.h" #include "config/streaming_config.h" #include "event_service.h" #include "flow_control.h" #include "message/message_bundle.h" +#include "reliability/barrier_helper.h" +#include "reliability_helper.h" #include "runtime_context.h" namespace ray { namespace streaming { +class ReliabilityHelper; /// DataWriter is designed for data transporting between upstream and downstream. /// After the user sends the data, it does not immediately send the data to @@ -57,6 +60,27 @@ class DataWriter { const ObjectID &q_id, uint8_t *data, uint32_t data_size, StreamingMessageType message_type = StreamingMessageType::Message); + /// Send barrier to all channel. note there are user defined data in barrier bundle + /// \param barrier_id + /// \param data + /// \param data_size + /// + void BroadcastBarrier(uint64_t barrier_id, const uint8_t *data, uint32_t data_size); + + /// To relieve stress from large source/input data, we define a new function + /// clear_check_point + /// in producer/writer class. 
Worker can invoke this function if and only if + /// notify_consumed each item + /// flag is passed in reader/consumer, which means writer's producing became more + /// rhythmical and reader + /// can't walk on old way anymore. + /// \param barrier_id: user-defined numerical checkpoint id + void ClearCheckpoint(uint64_t barrier_id); + + /// Replay all queue from checkpoint, it's useful under FO + /// \param result offset vector + void GetChannelOffset(std::vector &result); + void Run(); void Stop(); @@ -112,6 +136,8 @@ class DataWriter { void FlowControlTimer(); + void ClearCheckpointId(ProducerChannelInfo &channel_info, uint64_t seq_id); + private: std::shared_ptr event_service_; @@ -124,6 +150,15 @@ class DataWriter { // unnecessary overflow. std::shared_ptr flow_controller_; + StreamingBarrierHelper barrier_helper_; + std::shared_ptr reliability_helper_; + + // Make thread-safe between loop thread and user thread. + // High-level runtime send notification about clear checkpoint if global + // checkpoint is finished and low-level will auto flush & evict item memory + // when no more space is available. + std::atomic_flag notify_flag_ = ATOMIC_FLAG_INIT; + protected: std::unordered_map channel_info_map_; /// ProducerChannel is middle broker for data transporting and all downstream diff --git a/streaming/src/event_service.cc b/streaming/src/event_service.cc index 926a4039b..dc1003f9e 100644 --- a/streaming/src/event_service.cc +++ b/streaming/src/event_service.cc @@ -57,6 +57,7 @@ void EventQueue::Pop() { no_full_cv_.notify_all(); } +constexpr int EventQueue::kConditionTimeoutMs; void EventQueue::WaitFor(std::unique_lock &lock) { // To avoid deadlock when EventQueue is empty but is_active is changed in other // thread, Event queue should awaken this condtion variable and check it again. 
diff --git a/streaming/src/event_service.h b/streaming/src/event_service.h index 298443800..874fff029 100644 --- a/streaming/src/event_service.h +++ b/streaming/src/event_service.h @@ -6,9 +6,8 @@ #include #include -#include "channel.h" -#include "ray/core_worker/core_worker.h" -#include "ring_buffer.h" +#include "channel/channel.h" +#include "ring_buffer/ring_buffer.h" #include "util/streaming_util.h" namespace ray { @@ -39,6 +38,12 @@ struct Event { ProducerChannelInfo *channel_info; EventType type; bool urgent; + Event() = default; + Event(ProducerChannelInfo *channel_info, EventType type, bool urgent) { + this->channel_info = channel_info; + this->type = type; + this->urgent = urgent; + } }; /// Data writer utilizes what's called an event-driven programming model diff --git a/streaming/src/flow_control.cc b/streaming/src/flow_control.cc index 77cde8128..b49d10c81 100644 --- a/streaming/src/flow_control.cc +++ b/streaming/src/flow_control.cc @@ -10,21 +10,21 @@ UnconsumedSeqFlowControl::UnconsumedSeqFlowControl( bool UnconsumedSeqFlowControl::ShouldFlowControl(ProducerChannelInfo &channel_info) { auto &queue_info = channel_info.queue_info; - if (queue_info.target_seq_id <= channel_info.current_seq_id) { + if (queue_info.target_message_id <= channel_info.current_message_id) { channel_map_[channel_info.channel_id]->RefreshChannelInfo(); // Target seq id is maximum upper limit in current condition. - channel_info.queue_info.target_seq_id = - channel_info.queue_info.consumed_seq_id + consumed_step_; - STREAMING_LOG(DEBUG) << "Flow control stop writing to downstream, current max id => " - << channel_info.current_seq_id << ", target seq id => " - << queue_info.target_seq_id << ", consumed_id => " - << queue_info.consumed_seq_id << ", q id => " - << channel_info.channel_id - << ". 
if this log keeps printing, it means something wrong " - "with queue's info API, or downstream node is not " - "consuming data."; + channel_info.queue_info.target_message_id = + channel_info.queue_info.consumed_message_id + consumed_step_; + STREAMING_LOG(DEBUG) + << "Flow control stop writing to downstream, current message id => " + << channel_info.current_message_id << ", target message id => " + << queue_info.target_message_id << ", consumed_id => " + << queue_info.consumed_message_id << ", q id => " << channel_info.channel_id + << ". if this log keeps printing, it means something wrong " + "with queue's info API, or downstream node is not " + "consuming data."; // Double check after refreshing if target seq id is changed. - if (queue_info.target_seq_id <= channel_info.current_seq_id) { + if (queue_info.target_message_id <= channel_info.current_message_id) { return true; } } diff --git a/streaming/src/flow_control.h b/streaming/src/flow_control.h index 0fdcb3291..005e75b8d 100644 --- a/streaming/src/flow_control.h +++ b/streaming/src/flow_control.h @@ -1,6 +1,6 @@ #pragma once -#include "channel.h" +#include "channel/channel.h" namespace ray { namespace streaming { diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.h deleted file mode 100644 index 839617026..000000000 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.h +++ /dev/null @@ -1,31 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class io_ray_streaming_runtime_transfer_ChannelId */ - -#ifndef _Included_io_ray_streaming_runtime_transfer_ChannelId -#define _Included_io_ray_streaming_runtime_transfer_ChannelId -#ifdef __cplusplus -extern "C" { -#endif -#undef io_ray_streaming_runtime_transfer_ChannelId_ID_LENGTH -#define io_ray_streaming_runtime_transfer_ChannelId_ID_LENGTH 20L -/* - * Class: io_ray_streaming_runtime_transfer_ChannelId - * 
Method: createNativeId - * Signature: (J)J - */ -JNIEXPORT jlong JNICALL -Java_io_ray_streaming_runtime_transfer_ChannelId_createNativeId(JNIEnv *, jclass, jlong); - -/* - * Class: io_ray_streaming_runtime_transfer_ChannelId - * Method: destroyNativeId - * Signature: (J)V - */ -JNIEXPORT void JNICALL -Java_io_ray_streaming_runtime_transfer_ChannelId_destroyNativeId(JNIEnv *, jclass, jlong); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc index bcbe68d39..ba4a4d405 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.cc @@ -12,15 +12,13 @@ using namespace ray::streaming; JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( JNIEnv *env, jclass, jobject streaming_queue_initial_parameters, - jobjectArray input_channels, jlongArray seq_id_array, jlongArray msg_id_array, - jlong timer_interval, jboolean isRecreate, jbyteArray config_bytes, - jboolean is_mock) { + jobjectArray input_channels, jlongArray msg_id_array, jlong timer_interval, + jobject creation_status, jbyteArray config_bytes, jboolean is_mock) { STREAMING_LOG(INFO) << "[JNI]: create DataReader."; std::vector parameter_vec; ParseChannelInitParameters(env, streaming_queue_initial_parameters, parameter_vec); std::vector input_channels_ids = jarray_to_object_id_vec(env, input_channels); - std::vector seq_ids = LongVectorFromJLongArray(env, seq_id_array).data; std::vector msg_ids = LongVectorFromJLongArray(env, msg_id_array).data; auto ctx = std::make_shared(); @@ -32,8 +30,24 @@ Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( if (is_mock) { ctx->MarkMockTest(); } + + // init reader auto reader = new DataReader(ctx); - reader->Init(input_channels_ids, parameter_vec, seq_ids, msg_ids, 
timer_interval); + std::vector creation_status_vec; + reader->Init(input_channels_ids, parameter_vec, msg_ids, creation_status_vec, + timer_interval); + + // add creation status to Java's List + jclass array_list_cls = env->GetObjectClass(creation_status); + jclass integer_cls = env->FindClass("java/lang/Integer"); + jmethodID array_list_add = + env->GetMethodID(array_list_cls, "add", "(Ljava/lang/Object;)Z"); + for (auto &status : creation_status_vec) { + jmethodID integer_init = env->GetMethodID(integer_cls, "", "(I)V"); + jobject integer_obj = + env->NewObject(integer_cls, integer_init, static_cast(status)); + env->CallBooleanMethod(creation_status, array_list_add, integer_obj); + } STREAMING_LOG(INFO) << "create native DataReader succeed"; return reinterpret_cast(reader); } @@ -51,8 +65,6 @@ JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBund } else if (StreamingStatus::GetBundleTimeOut == status) { } else if (StreamingStatus::InitQueueFailed == status) { throwRuntimeException(env, "init channel failed"); - } else if (StreamingStatus::WaitQueueTimeOut == status) { - throwRuntimeException(env, "wait channel object timeout"); } if (StreamingStatus::OK != status) { @@ -88,3 +100,34 @@ Java_io_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env, jlong ptr) { delete reinterpret_cast(ptr); } + +JNIEXPORT jbyteArray JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_getOffsetsInfoNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + auto reader = reinterpret_cast(ptr); + std::unordered_map *offset_map = nullptr; + reader->GetOffsetInfo(offset_map); + STREAMING_CHECK(offset_map); + // queue nums + (queue id + seq id + message id) * queue nums + int offset_data_size = + sizeof(uint32_t) + (kUniqueIDSize + sizeof(uint64_t) * 2) * offset_map->size(); + jbyteArray offsets_info = env->NewByteArray(offset_data_size); + int offset = 0; + // total queue nums + auto queue_nums = static_cast(offset_map->size()); + 
env->SetByteArrayRegion(offsets_info, offset, sizeof(uint32_t), + reinterpret_cast(&queue_nums)); + offset += sizeof(uint32_t); + // queue name & offset + for (auto &p : *offset_map) { + env->SetByteArrayRegion(offsets_info, offset, kUniqueIDSize, + reinterpret_cast(p.first.Data())); + offset += kUniqueIDSize; + // msg_id + env->SetByteArrayRegion(offsets_info, offset, sizeof(uint64_t), + reinterpret_cast(&p.second.current_message_id)); + offset += sizeof(uint64_t); + } + return offsets_info; +} \ No newline at end of file diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h index 221fe3105..43f677d34 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataReader.h @@ -1,3 +1,17 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ /* DO NOT EDIT THIS FILE - it is machine generated */ #include /* Header for class io_ray_streaming_runtime_transfer_DataReader */ @@ -11,12 +25,12 @@ extern "C" { * Class: io_ray_streaming_runtime_transfer_DataReader * Method: createDataReaderNative * Signature: - * (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[J[JJZ[BZ)J + * (Lio/ray/streaming/runtime/transfer/ChannelCreationParametersBuilder;[[B[JJLjava/util/List;[BZ)J */ JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( - JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlongArray, jlong, jboolean, - jbyteArray, jboolean); + JNIEnv *, jclass, jobject, jobjectArray, jlongArray, jlong, jobject, jbyteArray, + jboolean); /* * Class: io_ray_streaming_runtime_transfer_DataReader @@ -26,6 +40,15 @@ Java_io_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataReader_getBundleNative( JNIEnv *, jobject, jlong, jlong, jlong, jlong); +/* + * Class: io_ray_streaming_runtime_transfer_DataReader + * Method: getOffsetsInfoNative + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_streaming_runtime_transfer_DataReader_getOffsetsInfoNative(JNIEnv *, jobject, + jlong); + /* * Class: io_ray_streaming_runtime_transfer_DataReader * Method: stopReaderNative diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc index 4c2fa8422..efccdf98e 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.cc @@ -79,4 +79,40 @@ Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env, jlong ptr) { auto *data_writer = reinterpret_cast(ptr); delete data_writer; -} \ No newline at end of file +} + +JNIEXPORT jlongArray JNICALL 
+Java_io_ray_streaming_runtime_transfer_DataWriter_getOutputMsgIdNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + DataWriter *writer_client = reinterpret_cast(ptr); + + std::vector result; + writer_client->GetChannelOffset(result); + + jlongArray jArray = env->NewLongArray(result.size()); + jlong jdata[result.size()]; + for (size_t i = 0; i < result.size(); ++i) { + *(jdata + i) = result[i]; + } + env->SetLongArrayRegion(jArray, 0, result.size(), jdata); + return jArray; +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_broadcastBarrierNative( + JNIEnv *env, jobject thisObj, jlong ptr, jlong checkpointId, jbyteArray data) { + STREAMING_LOG(INFO) << "jni: broadcast barrier, cp_id=" << checkpointId; + RawDataFromJByteArray raw_data(env, data); + DataWriter *writer_client = reinterpret_cast(ptr); + writer_client->BroadcastBarrier(checkpointId, raw_data.data, raw_data.data_size); +} + +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_clearCheckpointNative( + JNIEnv *env, jobject thisObj, jlong ptr, jlong checkpointId) { + STREAMING_LOG(INFO) << "[Producer] jni: clearCheckpoints."; + auto *writer = reinterpret_cast(ptr); + writer->ClearCheckpoint(checkpointId); + STREAMING_LOG(INFO) << "[Producer] clear checkpoint done."; +} diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h index d445638b9..ff6ebb839 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_DataWriter.h @@ -1,3 +1,17 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* DO NOT EDIT THIS FILE - it is machine generated */ #include /* Header for class io_ray_streaming_runtime_transfer_DataWriter */ @@ -44,6 +58,35 @@ JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *, jobject, jlong); +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: getOutputMsgIdNative + * Signature: (J)[J + */ +JNIEXPORT jlongArray JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_getOutputMsgIdNative(JNIEnv *, jobject, + jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: broadcastBarrierNative + * Signature: (JJJ[B)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_broadcastBarrierNative(JNIEnv *, + jobject, jlong, + jlong, + jbyteArray); + +/* + * Class: io_ray_streaming_runtime_transfer_DataWriter + * Method: clearCheckpointNative + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_DataWriter_clearCheckpointNative(JNIEnv *, jobject, + jlong, jlong); + #ifdef __cplusplus } #endif diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h index 24517d7c4..4e5c826f5 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_TransferHandler.h @@ -1,3 +1,17 @@ +// Copyright 2017 The Ray Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + /* DO NOT EDIT THIS FILE - it is machine generated */ #include /* Header for class io_ray_streaming_runtime_transfer_TransferHandler */ @@ -10,7 +24,7 @@ extern "C" { /* * Class: io_ray_streaming_runtime_transfer_TransferHandler * Method: createWriterClientNative - * Signature: (J)J + * Signature: ()J */ JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *, @@ -19,7 +33,7 @@ Java_io_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative( /* * Class: io_ray_streaming_runtime_transfer_TransferHandler * Method: createReaderClientNative - * Signature: (J)J + * Signature: ()J */ JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *, diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.cc b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc similarity index 64% rename from streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.cc rename to streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc index 891d95e81..241db2b45 100644 --- a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_ChannelId.cc +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.cc @@ -1,17 +1,18 @@ -#include "io_ray_streaming_runtime_transfer_ChannelId.h" - +#include 
"io_ray_streaming_runtime_transfer_channel_ChannelId.h" #include "streaming_jni_common.h" using namespace ray::streaming; -JNIEXPORT jlong JNICALL Java_io_ray_streaming_runtime_transfer_ChannelId_createNativeId( +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_channel_ChannelId_createNativeId( JNIEnv *env, jclass cls, jlong qid_address) { auto id = ray::ObjectID::FromBinary( std::string(reinterpret_cast(qid_address), ray::ObjectID::Size())); return reinterpret_cast(new ray::ObjectID(id)); } -JNIEXPORT void JNICALL Java_io_ray_streaming_runtime_transfer_ChannelId_destroyNativeId( +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_channel_ChannelId_destroyNativeId( JNIEnv *env, jclass cls, jlong native_id_ptr) { auto id = reinterpret_cast(native_id_ptr); STREAMING_CHECK(id != nullptr); diff --git a/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h new file mode 100644 index 000000000..ab1295afd --- /dev/null +++ b/streaming/src/lib/java/io_ray_streaming_runtime_transfer_channel_ChannelId.h @@ -0,0 +1,47 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class io_ray_streaming_runtime_transfer_channel_ChannelId */ + +#ifndef _Included_io_ray_streaming_runtime_transfer_channel_ChannelId +#define _Included_io_ray_streaming_runtime_transfer_channel_ChannelId +#ifdef __cplusplus +extern "C" { +#endif +#undef io_ray_streaming_runtime_transfer_channel_ChannelId_ID_LENGTH +#define io_ray_streaming_runtime_transfer_channel_ChannelId_ID_LENGTH 20L +/* + * Class: io_ray_streaming_runtime_transfer_channel_ChannelId + * Method: createNativeId + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_io_ray_streaming_runtime_transfer_channel_ChannelId_createNativeId(JNIEnv *, jclass, + jlong); + +/* + * Class: io_ray_streaming_runtime_transfer_channel_ChannelId + * Method: destroyNativeId + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_io_ray_streaming_runtime_transfer_channel_ChannelId_destroyNativeId(JNIEnv *, jclass, + jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/streaming_jni_common.cc b/streaming/src/lib/java/streaming_jni_common.cc index 197f7c251..d2ec13950 100644 --- a/streaming/src/lib/java/streaming_jni_common.cc +++ b/streaming/src/lib/java/streaming_jni_common.cc @@ -136,7 +136,7 @@ std::shared_ptr FunctionDescriptorToRayFunction( &function_descriptor_list); ray::FunctionDescriptor function_descriptor = ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list); - ray::RayFunction ray_function{language, function_descriptor}; + ray::RayFunction ray_function(language, function_descriptor); return std::make_shared(ray_function); } diff --git a/streaming/src/lib/java/streaming_jni_common.h b/streaming/src/lib/java/streaming_jni_common.h index 996d58d21..1b70dbfc7 100644 --- a/streaming/src/lib/java/streaming_jni_common.h +++ b/streaming/src/lib/java/streaming_jni_common.h @@ -4,7 +4,7 @@ #include -#include "channel.h" +#include "channel/channel.h" #include 
"ray/core_worker/common.h" #include "util/streaming_logging.h" diff --git a/streaming/src/message/message.cc b/streaming/src/message/message.cc index cd44f76ee..3216e3b3a 100644 --- a/streaming/src/message/message.cc +++ b/streaming/src/message/message.cc @@ -10,30 +10,32 @@ namespace ray { namespace streaming { -StreamingMessage::StreamingMessage(std::shared_ptr &data, uint32_t data_size, - uint64_t seq_id, StreamingMessageType message_type) - : message_data_(data), - data_size_(data_size), +StreamingMessage::StreamingMessage(std::shared_ptr &payload_data, + uint32_t payload_size, uint64_t msg_id, + StreamingMessageType message_type) + : payload_(payload_data), + payload_size_(payload_size), message_type_(message_type), - message_id_(seq_id) {} + message_id_(msg_id) {} -StreamingMessage::StreamingMessage(std::shared_ptr &&data, uint32_t data_size, - uint64_t seq_id, StreamingMessageType message_type) - : message_data_(data), - data_size_(data_size), +StreamingMessage::StreamingMessage(std::shared_ptr &&payload_data, + uint32_t payload_size, uint64_t msg_id, + StreamingMessageType message_type) + : payload_(payload_data), + payload_size_(payload_size), message_type_(message_type), - message_id_(seq_id) {} + message_id_(msg_id) {} -StreamingMessage::StreamingMessage(const uint8_t *data, uint32_t data_size, - uint64_t seq_id, StreamingMessageType message_type) - : data_size_(data_size), message_type_(message_type), message_id_(seq_id) { - message_data_.reset(new uint8_t[data_size], std::default_delete()); - std::memcpy(message_data_.get(), data, data_size_); +StreamingMessage::StreamingMessage(const uint8_t *payload_data, uint32_t payload_size, + uint64_t msg_id, StreamingMessageType message_type) + : payload_size_(payload_size), message_type_(message_type), message_id_(msg_id) { + payload_.reset(new uint8_t[payload_size], std::default_delete()); + std::memcpy(payload_.get(), payload_data, payload_size); } StreamingMessage::StreamingMessage(const StreamingMessage 
&msg) { - data_size_ = msg.data_size_; - message_data_ = msg.message_data_; + payload_size_ = msg.payload_size_; + payload_ = msg.payload_; message_id_ = msg.message_id_; message_type_ = msg.message_type_; } @@ -44,8 +46,8 @@ StreamingMessagePtr StreamingMessage::FromBytes(const uint8_t *bytes, uint32_t data_size = *reinterpret_cast(bytes + byte_offset); byte_offset += sizeof(data_size); - uint64_t seq_id = *reinterpret_cast(bytes + byte_offset); - byte_offset += sizeof(seq_id); + uint64_t msg_id = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(msg_id); StreamingMessageType msg_type = *reinterpret_cast(bytes + byte_offset); @@ -54,14 +56,14 @@ StreamingMessagePtr StreamingMessage::FromBytes(const uint8_t *bytes, auto buf = new uint8_t[data_size]; std::memcpy(buf, bytes + byte_offset, data_size); auto data_ptr = std::shared_ptr(buf, std::default_delete()); - return std::make_shared(data_ptr, data_size, seq_id, msg_type); + return std::make_shared(data_ptr, data_size, msg_id, msg_type); } void StreamingMessage::ToBytes(uint8_t *serlizable_data) { uint32_t byte_offset = 0; - std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&data_size_), - sizeof(data_size_)); - byte_offset += sizeof(data_size_); + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&payload_size_), + sizeof(payload_size_)); + byte_offset += sizeof(payload_size_); std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&message_id_), sizeof(message_id_)); @@ -71,19 +73,28 @@ void StreamingMessage::ToBytes(uint8_t *serlizable_data) { sizeof(message_type_)); byte_offset += sizeof(message_type_); - std::memcpy(serlizable_data + byte_offset, - reinterpret_cast(message_data_.get()), data_size_); + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(payload_.get()), + payload_size_); - byte_offset += data_size_; + byte_offset += payload_size_; STREAMING_CHECK(byte_offset == this->ClassBytesSize()); } bool StreamingMessage::operator==(const StreamingMessage 
&message) const { - return GetDataSize() == message.GetDataSize() && - GetMessageSeqId() == message.GetMessageSeqId() && + return PayloadSize() == message.PayloadSize() && + GetMessageId() == message.GetMessageId() && GetMessageType() == message.GetMessageType() && - !std::memcmp(RawData(), message.RawData(), data_size_); + !std::memcmp(Payload(), message.Payload(), PayloadSize()); +} + +std::ostream &operator<<(std::ostream &os, const StreamingMessage &message) { + os << "{" + << " message_type_: " << static_cast(message.GetMessageType()) + << " message_id_: " << message.GetMessageId() + << " payload_size_: " << message.payload_size_ + << " payload_: " << (void *)message.payload_.get() << "}"; + return os; } } // namespace streaming diff --git a/streaming/src/message/message.h b/streaming/src/message/message.h index 66a78e9f4..3052e9aa5 100644 --- a/streaming/src/message/message.h +++ b/streaming/src/message/message.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace ray { @@ -16,52 +17,80 @@ enum class StreamingMessageType : uint32_t { MAX = Message }; +enum class StreamingBarrierType : uint32_t { GlobalBarrier = 0 }; + +struct StreamingBarrierHeader { + StreamingBarrierType barrier_type; + uint64_t barrier_id; + StreamingBarrierHeader() = default; + StreamingBarrierHeader(StreamingBarrierType barrier_type, uint64_t barrier_id) { + this->barrier_type = barrier_type; + this->barrier_id = barrier_id; + } + inline bool IsGlobalBarrier() { + return StreamingBarrierType::GlobalBarrier == barrier_type; + } +}; + constexpr uint32_t kMessageHeaderSize = sizeof(uint32_t) + sizeof(uint64_t) + sizeof(StreamingMessageType); +constexpr uint32_t kBarrierHeaderSize = sizeof(StreamingBarrierType) + sizeof(uint64_t); + /// All messages should be wrapped by this protocol. // DataSize means length of raw data, message id is increasing from [1, +INF]. // MessageType will be used for barrier transporting and checkpoint. 
/// +----------------+ -/// | DataSize=U32 | +/// | PayloadSize=U32| /// +----------------+ /// | MessageId=U64 | /// +----------------+ /// | MessageType=U32| /// +----------------+ -/// | Data=var | +/// | Payload=var | /// +----------------+ +/// Payload field contains barrier header and carried buffer if message type is +/// global/partial barrier. +/// +/// Barrier's Payload field: +/// +----------------------------+ +/// | StreamingBarrierType=U32 | +/// +----------------------------+ +/// | barrier_id=U64 | +/// +----------------------------+ +/// | carried_buffer=var | +/// +----------------------------+ class StreamingMessage { private: - std::shared_ptr message_data_; - uint32_t data_size_; + std::shared_ptr payload_; + uint32_t payload_size_; StreamingMessageType message_type_; uint64_t message_id_; public: /// Copy raw data from outside shared buffer. - /// \param data raw data from user buffer - /// \param data_size raw data size - /// \param seq_id message id + /// \param payload_ raw data from user buffer + /// \param payload_size_ raw data size + /// \param msg_id message id /// \param message_type - StreamingMessage(std::shared_ptr &data, uint32_t data_size, uint64_t seq_id, - StreamingMessageType message_type); + StreamingMessage(std::shared_ptr &payload_data, uint32_t payload_size, + uint64_t msg_id, StreamingMessageType message_type); /// Move outsite raw data to message data. - /// \param data raw data from user buffer - /// \param data_size raw data size - /// \param seq_id message id + /// \param payload_ raw data from user buffer + /// \param payload_size_ raw data size + /// \param msg_id message id /// \param message_type - StreamingMessage(std::shared_ptr &&data, uint32_t data_size, uint64_t seq_id, - StreamingMessageType message_type); + StreamingMessage(std::shared_ptr &&payload_data, uint32_t payload_size, + uint64_t msg_id, StreamingMessageType message_type); /// Copy raw data from outside buffer. 
- /// \param data raw data from user buffer - /// \param data_size raw data size - /// \param seq_id message id + /// \param payload_ raw data from user buffer + /// \param payload_size_ raw data size + /// \param msg_id message id /// \param message_type - StreamingMessage(const uint8_t *data, uint32_t data_size, uint64_t seq_id, + StreamingMessage(const uint8_t *payload_data, uint32_t payload_size, uint64_t msg_id, StreamingMessageType message_type); StreamingMessage(const StreamingMessage &); @@ -70,20 +99,44 @@ class StreamingMessage { virtual ~StreamingMessage() = default; - inline uint8_t *RawData() const { return message_data_.get(); } - - inline uint32_t GetDataSize() const { return data_size_; } inline StreamingMessageType GetMessageType() const { return message_type_; } - inline uint64_t GetMessageSeqId() const { return message_id_; } + inline uint64_t GetMessageId() const { return message_id_; } + + inline uint8_t *Payload() const { return payload_.get(); } + + inline uint32_t PayloadSize() const { return payload_size_; } + inline bool IsMessage() { return StreamingMessageType::Message == message_type_; } inline bool IsBarrier() { return StreamingMessageType::Barrier == message_type_; } bool operator==(const StreamingMessage &) const; + static inline std::shared_ptr MakeBarrierPayload( + StreamingBarrierHeader &barrier_header, const uint8_t *data, uint32_t data_size) { + std::shared_ptr ptr(new uint8_t[data_size + kBarrierHeaderSize], + std::default_delete()); + std::memcpy(ptr.get(), &barrier_header.barrier_type, sizeof(StreamingBarrierType)); + std::memcpy(ptr.get() + sizeof(StreamingBarrierType), &barrier_header.barrier_id, + sizeof(uint64_t)); + if (data && data_size > 0) { + std::memcpy(ptr.get() + kBarrierHeaderSize, data, data_size); + } + return ptr; + } + virtual void ToBytes(uint8_t *data); static StreamingMessagePtr FromBytes(const uint8_t *data, bool verifer_check = true); - inline virtual uint32_t ClassBytesSize() { return kMessageHeaderSize 
+ data_size_; } + inline virtual uint32_t ClassBytesSize() { return kMessageHeaderSize + payload_size_; } + + static inline void GetBarrierIdFromRawData(const uint8_t *data, + StreamingBarrierHeader *barrier_header) { + barrier_header->barrier_type = *reinterpret_cast(data); + barrier_header->barrier_id = + *reinterpret_cast(data + sizeof(StreamingBarrierType)); + } + + friend std::ostream &operator<<(std::ostream &os, const StreamingMessage &message); }; } // namespace streaming diff --git a/streaming/src/message/message_bundle.cc b/streaming/src/message/message_bundle.cc index 13057c428..629a0613d 100644 --- a/streaming/src/message/message_bundle.cc +++ b/streaming/src/message/message_bundle.cc @@ -63,6 +63,14 @@ bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta *meta) co return operator==(*meta); } +std::ostream &operator<<(std::ostream &os, const StreamingMessageBundleMeta &meta) { + os << "{" + << "last_message_id_: " << meta.last_message_id_ + << ", message_list_size_: " << meta.message_list_size_ + << ", bundle_type_: " << static_cast(meta.bundle_type_) << "}"; + return os; +} + StreamingMessageBundleMeta::StreamingMessageBundleMeta() : bundle_type_(StreamingMessageBundleType::Empty) {} @@ -188,5 +196,13 @@ bool StreamingMessageBundle::operator==(StreamingMessageBundle &bundle) const { bool StreamingMessageBundle::operator==(StreamingMessageBundle *bundle) const { return this->operator==(*bundle); } + +std::ostream &operator<<(std::ostream &os, const DataBundle &bundle) { + os << "{" + << "data: " << (void *)bundle.data << ", data_size: " << bundle.data_size + << ", channel last_barrier_id: " << bundle.last_barrier_id + << ", meta: " << *(bundle.meta) << "}"; + return os; +} } // namespace streaming } // namespace ray diff --git a/streaming/src/message/message_bundle.h b/streaming/src/message/message_bundle.h index a5f8687ca..2cad05e36 100644 --- a/streaming/src/message/message_bundle.h +++ b/streaming/src/message/message_bundle.h @@ 
-7,6 +7,7 @@ #include #include "message/message.h" +#include "ray/common/id.h" namespace ray { namespace streaming { @@ -83,6 +84,7 @@ class StreamingMessageBundleMeta { inline bool IsBarrier() { return StreamingMessageBundleType::Barrier == bundle_type_; } inline bool IsBundle() { return StreamingMessageBundleType::Bundle == bundle_type_; } + inline bool IsEmptyMsg() { return StreamingMessageBundleType::Empty == bundle_type_; } virtual void ToBytes(uint8_t *data); static StreamingMessageBundleMetaPtr FromBytes(const uint8_t *data, @@ -99,6 +101,9 @@ class StreamingMessageBundleMeta { "," + std::to_string(message_bundle_ts_) + "," + std::to_string(static_cast(bundle_type_)); } + + friend std::ostream &operator<<(std::ostream &os, + const StreamingMessageBundleMeta &meta); }; /// StreamingMessageBundle inherits from metadata class (StreamingMessageBundleMeta) @@ -177,5 +182,30 @@ class StreamingMessageBundle : public StreamingMessageBundleMeta { const std::list &message_list, uint32_t raw_data_size, uint8_t *raw_data); }; + +/// Databundle is super-bundle that contains channel information (upstream +/// channel id & bundle meta data) and raw buffer pointer. 
+struct DataBundle { + uint8_t *data = nullptr; + uint32_t data_size; + ObjectID from; + uint32_t last_barrier_id; + StreamingMessageBundleMetaPtr meta; + bool is_reallocated = false; + + ~DataBundle() { + if (is_reallocated) { + delete[] data; + } + } + + void Realloc(uint32_t size) { + data = new uint8_t[size]; + is_reallocated = true; + } + + friend std::ostream &operator<<(std::ostream &os, const DataBundle &bundle); +}; + } // namespace streaming } // namespace ray diff --git a/streaming/src/protobuf/remote_call.proto b/streaming/src/protobuf/remote_call.proto index 5e9e2a754..34c2dac7b 100644 --- a/streaming/src/protobuf/remote_call.proto +++ b/streaming/src/protobuf/remote_call.proto @@ -4,6 +4,8 @@ package ray.streaming.proto; import "protobuf/streaming.proto"; +import "google/protobuf/any.proto"; + option java_package = "io.ray.streaming.runtime.generated"; // Execution vertex info, including it's upstream and downstream @@ -22,7 +24,7 @@ message ExecutionVertexContext { // unique id of execution vertex int32 execution_vertex_id = 1; // unique id of execution job vertex - int32 execution_job_vertex_Id = 2; + int32 execution_job_vertex_id = 2; // name of execution job vertex, e.g. 
1-SourceOperator string execution_job_vertex_name = 3; // index of execution vertex @@ -56,3 +58,48 @@ message PythonJobWorkerContext { // vertex including it's upstream and downstream ExecutionVertexContext execution_vertex_context = 2; } + +message BoolResult { + bool boolRes = 1; +} + +message Barrier { + int64 id = 1; +} + +message CheckpointId { + int64 checkpoint_id = 1; +} + +message BaseWorkerCmd { + bytes actor_id = 1; // actor id + int64 timestamp = 2; + google.protobuf.Any detail = 3; +} + +message WorkerCommitReport { + int64 commit_checkpoint_id = 1; +} + +message WorkerRollbackRequest { + string exception_msg = 1; + string worker_hostname = 2; + string worker_pid = 3; +} + +message CallResult { + bool success = 1; + int32 result_code = 2; + string result_msg = 3; + QueueRecoverInfo result_obj = 4; +} + +message QueueRecoverInfo { + enum QueueCreationStatus { + FreshStarted = 0; + PullOk = 1; + Timeout = 2; + DataLost = 3; + } + map creation_status = 3; +} \ No newline at end of file diff --git a/streaming/src/protobuf/streaming.proto b/streaming/src/protobuf/streaming.proto index e79143556..9f26c20ef 100644 --- a/streaming/src/protobuf/streaming.proto +++ b/streaming/src/protobuf/streaming.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package ray.streaming.proto; +import "google/protobuf/any.proto"; + option java_package = "io.ray.streaming.runtime.generated"; enum Language { @@ -20,6 +22,12 @@ enum NodeType { SINK = 3; } +enum ReliabilityLevel { + NONE = 0; + AT_LEAST_ONCE = 1; + EXACTLY_ONCE = 2; +} + enum FlowControlType { UNKNOWN_FLOW_CONTROL_TYPE = 0; UnconsumedSeqFlowControl = 1; diff --git a/streaming/src/queue/message.cc b/streaming/src/queue/message.cc index ff1edad2a..b0395a806 100644 --- a/streaming/src/queue/message.cc +++ b/streaming/src/queue/message.cc @@ -90,7 +90,7 @@ std::shared_ptr DataMessage::FromBytes(uint8_t *bytes) { void NotificationMessage::ToProtobuf(std::string *output) { queue::protobuf::StreamingQueueNotificationMsg msg; 
FillMessageCommon(msg.mutable_common()); - msg.set_seq_id(seq_id_); + msg.set_seq_id(msg_id_); msg.SerializeToString(output); } diff --git a/streaming/src/queue/message.h b/streaming/src/queue/message.h index 42b474af2..9438a4714 100644 --- a/streaming/src/queue/message.h +++ b/streaming/src/queue/message.h @@ -102,19 +102,19 @@ class DataMessage : public Message { class NotificationMessage : public Message { public: NotificationMessage(const ActorID &actor_id, const ActorID &peer_actor_id, - const ObjectID &queue_id, uint64_t seq_id) - : Message(actor_id, peer_actor_id, queue_id), seq_id_(seq_id) {} + const ObjectID &queue_id, uint64_t msg_id) + : Message(actor_id, peer_actor_id, queue_id), msg_id_(msg_id) {} virtual ~NotificationMessage() {} static std::shared_ptr FromBytes(uint8_t *bytes); virtual void ToProtobuf(std::string *output); - inline uint64_t SeqId() { return seq_id_; } + inline uint64_t MsgId() { return msg_id_; } inline queue::protobuf::StreamingQueueMessageType Type() { return type_; } private: - uint64_t seq_id_; + uint64_t msg_id_; const queue::protobuf::StreamingQueueMessageType type_ = queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType; }; diff --git a/streaming/src/queue/queue.cc b/streaming/src/queue/queue.cc index 434420ca3..1885a460e 100644 --- a/streaming/src/queue/queue.cc +++ b/streaming/src/queue/queue.cc @@ -101,9 +101,8 @@ size_t Queue::PendingCount() { return begin->SeqId() - end->SeqId() + 1; } -Status WriterQueue::Push(uint64_t seq_id, uint8_t *buffer, uint32_t buffer_size, - uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end, - bool raw) { +Status WriterQueue::Push(uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, + uint64_t msg_id_start, uint64_t msg_id_end, bool raw) { if (IsPendingFull(buffer_size)) { return Status::OutOfMemory("Queue Push OutOfMemory"); } @@ -113,9 +112,9 @@ Status WriterQueue::Push(uint64_t seq_id, uint8_t *buffer, uint32_t buffer_size, 
std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - QueueItem item(seq_id, buffer, buffer_size, timestamp, msg_id_start, msg_id_end, raw); + QueueItem item(seq_id_, buffer, buffer_size, timestamp, msg_id_start, msg_id_end, raw); Queue::Push(item); - STREAMING_LOG(DEBUG) << "WriterQueue::Push seq_id_: " << seq_id_; + STREAMING_LOG(DEBUG) << "WriterQueue::Push seq_id: " << seq_id_; seq_id_++; return Status::OK(); } @@ -132,33 +131,41 @@ void WriterQueue::Send() { } Status WriterQueue::TryEvictItems() { - STREAMING_LOG(INFO) << "TryEvictItems"; QueueItem item = FrontProcessed(); - uint64_t first_seq_id = item.SeqId(); - STREAMING_LOG(INFO) << "TryEvictItems first_seq_id: " << first_seq_id - << " min_consumed_id_: " << min_consumed_id_ - << " eviction_limit_: " << eviction_limit_; - if (min_consumed_id_ == QUEUE_INVALID_SEQ_ID || first_seq_id > min_consumed_id_) { + STREAMING_LOG(DEBUG) << "TryEvictItems queue_id: " << queue_id_ << " first_item: (" + << item.MsgIdStart() << "," << item.MsgIdEnd() << ")" + << " min_consumed_msg_id_: " << min_consumed_msg_id_ + << " eviction_limit_: " << eviction_limit_ + << " max_data_size_: " << max_data_size_ + << " data_size_sent_: " << data_size_sent_ + << " data_size_: " << data_size_; + + if (min_consumed_msg_id_ == QUEUE_INVALID_SEQ_ID || + min_consumed_msg_id_ < item.MsgIdEnd()) { return Status::OutOfMemory("The queue is full and some reader doesn't consume"); } - if (eviction_limit_ == QUEUE_INVALID_SEQ_ID || first_seq_id > eviction_limit_) { + if (eviction_limit_ == QUEUE_INVALID_SEQ_ID || eviction_limit_ < item.MsgIdEnd()) { return Status::OutOfMemory("The queue is full and eviction limit block evict"); } - uint64_t evict_target_seq_id = std::min(min_consumed_id_, eviction_limit_); + uint64_t evict_target_msg_id = std::min(min_consumed_msg_id_, eviction_limit_); - while (item.SeqId() <= evict_target_seq_id) { + int count = 0; + while (item.MsgIdEnd() <= evict_target_msg_id) { PopProcessed(); - STREAMING_LOG(INFO) 
<< "TryEvictItems directly " << item.SeqId(); + STREAMING_LOG(INFO) << "TryEvictItems directly " << item.MsgIdEnd(); item = FrontProcessed(); + count++; } + STREAMING_LOG(DEBUG) << count << " items evicted, current item: (" << item.MsgIdStart() + << "," << item.MsgIdEnd() << ")"; return Status::OK(); } void WriterQueue::OnNotify(std::shared_ptr notify_msg) { - STREAMING_LOG(INFO) << "OnNotify target seq_id: " << notify_msg->SeqId(); - min_consumed_id_ = notify_msg->SeqId(); + STREAMING_LOG(INFO) << "OnNotify target msg_id: " << notify_msg->MsgId(); + min_consumed_msg_id_ = notify_msg->MsgId(); } void WriterQueue::ResendItem(QueueItem &item, uint64_t first_seq_id, @@ -273,22 +280,22 @@ void WriterQueue::OnPull( }); } -void ReaderQueue::OnConsumed(uint64_t seq_id) { - STREAMING_LOG(INFO) << "OnConsumed: " << seq_id; +void ReaderQueue::OnConsumed(uint64_t msg_id) { + STREAMING_LOG(INFO) << "OnConsumed: " << msg_id; QueueItem item = FrontProcessed(); - while (item.SeqId() <= seq_id) { + while (item.MsgIdEnd() <= msg_id) { PopProcessed(); item = FrontProcessed(); } - Notify(seq_id); + Notify(msg_id); } -void ReaderQueue::Notify(uint64_t seq_id) { +void ReaderQueue::Notify(uint64_t msg_id) { std::vector task_args; - CreateNotifyTask(seq_id, task_args); + CreateNotifyTask(msg_id, task_args); // SubmitActorTask - NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, seq_id); + NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, msg_id); std::unique_ptr buffer = msg.ToBytes(); transport_->Send(std::move(buffer)); @@ -298,7 +305,10 @@ void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector &task_a void ReaderQueue::OnData(QueueItem &item) { last_recv_seq_id_ = item.SeqId(); - STREAMING_LOG(DEBUG) << "ReaderQueue::OnData seq_id: " << last_recv_seq_id_; + last_recv_msg_id_ = item.MsgIdEnd(); + STREAMING_LOG(DEBUG) << "ReaderQueue::OnData queue_id: " << queue_id_ + << " seq_id: " << last_recv_seq_id_ << " msg_id: (" + << item.MsgIdStart() << "," << 
item.MsgIdEnd() << ")"; Push(item); } diff --git a/streaming/src/queue/queue.h b/streaming/src/queue/queue.h index b15c3cfe8..758b90934 100644 --- a/streaming/src/queue/queue.h +++ b/streaming/src/queue/queue.h @@ -94,10 +94,10 @@ class Queue { inline size_t Count() { return buffer_queue_.size(); } /// Return item count in pending state. - size_t PendingCount(); + inline size_t PendingCount(); /// Return item count in processed state. - size_t ProcessedCount(); + inline size_t ProcessedCount(); inline ActorID GetActorID() { return actor_id_; } inline ActorID GetPeerActorID() { return peer_actor_id_; } @@ -135,7 +135,7 @@ class WriterQueue : public Queue { peer_actor_id_(peer_actor_id), seq_id_(QUEUE_INITIAL_SEQ_ID), eviction_limit_(QUEUE_INVALID_SEQ_ID), - min_consumed_id_(QUEUE_INVALID_SEQ_ID), + min_consumed_msg_id_(QUEUE_INVALID_SEQ_ID), peer_last_msg_id_(0), peer_last_seq_id_(QUEUE_INVALID_SEQ_ID), transport_(transport), @@ -143,12 +143,14 @@ class WriterQueue : public Queue { is_upstream_first_pull_(true) {} /// Push a continuous buffer into queue, the buffer consists of some messages packed by - /// DataWriter. \param data, the buffer address \param data_size, buffer size \param - /// timestamp, the timestamp when the buffer pushed in \param msg_id_start, the message - /// id of the first message in the buffer \param msg_id_end, the message id of the last - /// message in the buffer \param raw, whether this buffer is raw data, be True only in - /// test - Status Push(uint64_t seq_id, uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, + /// DataWriter. 
+ /// \param data, the buffer address + /// \param data_size, buffer size + /// \param timestamp, the timestamp when the buffer pushed in + /// \param msg_id_start, the message id of the first message in the buffer + /// \param msg_id_end, the message id of the last message in the buffer + /// \param raw, whether this buffer is raw data, be True only in test + Status Push(uint8_t *buffer, uint32_t buffer_size, uint64_t timestamp, uint64_t msg_id_start, uint64_t msg_id_end, bool raw = false); /// Callback function, will be called when downstream queue notifies @@ -167,16 +169,14 @@ class WriterQueue : public Queue { void Send(); /// Called when user pushs item into queue. The count of items - /// can be evicted, determined by eviction_limit_ and min_consumed_id_. + /// can be evicted, determined by eviction_limit_ and min_consumed_msg_id_. Status TryEvictItems(); - void SetQueueEvictionLimit(uint64_t eviction_limit) { - eviction_limit_ = eviction_limit; - } + void SetQueueEvictionLimit(uint64_t msg_id) { eviction_limit_ = msg_id; } uint64_t EvictionLimit() { return eviction_limit_; } - uint64_t GetMinConsumedSeqID() { return min_consumed_id_; } + uint64_t GetMinConsumedMsgID() { return min_consumed_msg_id_; } void SetPeerLastIds(uint64_t msg_id, uint64_t seq_id) { peer_last_msg_id_ = msg_id; @@ -215,7 +215,7 @@ class WriterQueue : public Queue { ActorID peer_actor_id_; uint64_t seq_id_; uint64_t eviction_limit_; - uint64_t min_consumed_id_; + uint64_t min_consumed_msg_id_; uint64_t peer_last_msg_id_; uint64_t peer_last_seq_id_; std::shared_ptr transport_; @@ -238,8 +238,8 @@ class ReaderQueue : public Queue { transport), actor_id_(actor_id), peer_actor_id_(peer_actor_id), - min_consumed_id_(QUEUE_INVALID_SEQ_ID), last_recv_seq_id_(QUEUE_INVALID_SEQ_ID), + last_recv_msg_id_(QUEUE_INVALID_SEQ_ID), transport_(transport) {} /// Delete processed items whose seq id <= seq_id, @@ -252,9 +252,8 @@ class ReaderQueue : public Queue { /// NOTE: this callback function is called 
in queue thread. void OnResendData(std::shared_ptr msg); - uint64_t GetMinConsumedSeqID() { return min_consumed_id_; } - - uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; } + inline uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; } + inline uint64_t GetLastRecvMsgId() { return last_recv_msg_id_; } private: void Notify(uint64_t seq_id); @@ -263,8 +262,8 @@ class ReaderQueue : public Queue { private: ActorID actor_id_; ActorID peer_actor_id_; - uint64_t min_consumed_id_; uint64_t last_recv_seq_id_; + uint64_t last_recv_msg_id_; std::shared_ptr promise_for_pull_; std::shared_ptr transport_; }; diff --git a/streaming/src/queue/queue_handler.cc b/streaming/src/queue/queue_handler.cc index 40a6033d4..c6ef288be 100644 --- a/streaming/src/queue/queue_handler.cc +++ b/streaming/src/queue/queue_handler.cc @@ -260,7 +260,7 @@ void UpstreamQueueMessageHandler::OnNotify( << queue::protobuf::StreamingQueueMessageType_Name( notify_msg->Type()) << ", maybe queue has been destroyed, ignore it." - << " seq id: " << notify_msg->SeqId(); + << " msg id: " << notify_msg->MsgId(); return; } queue->OnNotify(notify_msg); diff --git a/streaming/src/queue/queue_item.h b/streaming/src/queue/queue_item.h index f3954f346..b63e0eb74 100644 --- a/streaming/src/queue/queue_item.h +++ b/streaming/src/queue/queue_item.h @@ -24,6 +24,7 @@ const uint64_t QUEUE_INITIAL_SEQ_ID = 1; /// LocalMemoryBuffer shared_ptr, which will be sent out by Transport. class QueueItem { public: + QueueItem() = default; /// Construct a QueueItem object. /// \param[in] seq_id the sequential id assigned by DataWriter for a message bundle and /// QueueItem. 
diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc index cd30955fa..6bb378e20 100644 --- a/streaming/src/queue/transport.cc +++ b/streaming/src/queue/transport.cc @@ -36,7 +36,7 @@ void Transport::SendInternal(std::shared_ptr buffer, } void Transport::Send(std::shared_ptr buffer) { - STREAMING_LOG(INFO) << "Transport::Send buffer size: " << buffer->Size(); + STREAMING_LOG(DEBUG) << "Transport::Send buffer size: " << buffer->Size(); std::vector return_ids; SendInternal(std::move(buffer), async_func_, TASK_OPTION_RETURN_NUM_0, return_ids); } diff --git a/streaming/src/reliability/barrier_helper.cc b/streaming/src/reliability/barrier_helper.cc new file mode 100644 index 000000000..14d66b790 --- /dev/null +++ b/streaming/src/reliability/barrier_helper.cc @@ -0,0 +1,165 @@ +#include "barrier_helper.h" + +#include + +#include "util/streaming_logging.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { +StreamingStatus StreamingBarrierHelper::GetMsgIdByBarrierId(const ObjectID &q_id, + uint64_t barrier_id, + uint64_t &msg_id) { + std::lock_guard lock(global_barrier_mutex_); + auto queue_map = global_barrier_map_.find(barrier_id); + if (queue_map == global_barrier_map_.end()) { + return StreamingStatus::NoSuchItem; + } + auto msg_id_map = queue_map->second.find(q_id); + if (msg_id_map == queue_map->second.end()) { + return StreamingStatus::QueueIdNotFound; + } + msg_id = msg_id_map->second; + return StreamingStatus::OK; +} + +void StreamingBarrierHelper::SetMsgIdByBarrierId(const ObjectID &q_id, + uint64_t barrier_id, uint64_t msg_id) { + std::lock_guard lock(global_barrier_mutex_); + global_barrier_map_[barrier_id][q_id] = msg_id; +} + +void StreamingBarrierHelper::ReleaseBarrierMapById(uint64_t barrier_id) { + std::lock_guard lock(global_barrier_mutex_); + global_barrier_map_.erase(barrier_id); +} + +void StreamingBarrierHelper::ReleaseAllBarrierMap() { + std::lock_guard lock(global_barrier_mutex_); + 
global_barrier_map_.clear(); +} + +void StreamingBarrierHelper::MapBarrierToCheckpoint(uint64_t barrier_id, + uint64_t checkpoint) { + std::lock_guard lock(barrier_map_checkpoint_mutex_); + barrier_checkpoint_map_[barrier_id] = checkpoint; +} + +StreamingStatus StreamingBarrierHelper::GetCheckpointIdByBarrierId( + uint64_t barrier_id, uint64_t &checkpoint_id) { + std::lock_guard lock(barrier_map_checkpoint_mutex_); + auto checkpoint_item = barrier_checkpoint_map_.find(barrier_id); + if (checkpoint_item == barrier_checkpoint_map_.end()) { + return StreamingStatus::NoSuchItem; + } + + checkpoint_id = checkpoint_item->second; + return StreamingStatus::OK; +} + +void StreamingBarrierHelper::ReleaseBarrierMapCheckpointByBarrierId( + const uint64_t barrier_id) { + std::lock_guard lock(barrier_map_checkpoint_mutex_); + auto it = barrier_checkpoint_map_.begin(); + while (it != barrier_checkpoint_map_.end()) { + if (it->first <= barrier_id) { + it = barrier_checkpoint_map_.erase(it); + } else { + it++; + } + } +} + +StreamingStatus StreamingBarrierHelper::GetBarrierIdByLastMessageId(const ObjectID &q_id, + uint64_t message_id, + uint64_t &barrier_id, + bool is_pop) { + std::lock_guard lock(message_id_map_barrier_mutex_); + auto message_item = global_reversed_barrier_map_.find(message_id); + if (message_item == global_reversed_barrier_map_.end()) { + return StreamingStatus::NoSuchItem; + } + + auto message_queue_item = message_item->second.find(q_id); + if (message_queue_item == message_item->second.end()) { + return StreamingStatus::QueueIdNotFound; + } + if (message_queue_item->second->empty()) { + STREAMING_LOG(WARNING) << "[Barrier] q id => " << q_id.Hex() << ", str num => " + << Util::Hexqid2str(q_id.Hex()) << ", message id " + << message_id; + return StreamingStatus::NoSuchItem; + } else { + barrier_id = message_queue_item->second->front(); + if (is_pop) { + message_queue_item->second->pop(); + } + } + return StreamingStatus::OK; +} + +void 
StreamingBarrierHelper::SetBarrierIdByLastMessageId(const ObjectID &q_id, + uint64_t message_id, + uint64_t barrier_id) { + std::lock_guard lock(message_id_map_barrier_mutex_); + + auto max_message_id_barrier = max_message_id_map_.find(q_id); + // remove finished barrier in different last message id + if (max_message_id_barrier != max_message_id_map_.end() && + max_message_id_barrier->second != message_id) { + if (global_reversed_barrier_map_.find(max_message_id_barrier->second) != + global_reversed_barrier_map_.end()) { + global_reversed_barrier_map_.erase(max_message_id_barrier->second); + } + } + + max_message_id_map_[q_id] = message_id; + auto message_item = global_reversed_barrier_map_.find(message_id); + if (message_item == global_reversed_barrier_map_.end()) { + BarrierIdQueue temp_queue = std::make_shared>(); + temp_queue->push(barrier_id); + global_reversed_barrier_map_[message_id][q_id] = temp_queue; + return; + } + auto message_queue_item = message_item->second.find(q_id); + if (message_queue_item != message_item->second.end()) { + message_queue_item->second->push(barrier_id); + } else { + BarrierIdQueue temp_queue = std::make_shared>(); + temp_queue->push(barrier_id); + global_reversed_barrier_map_[message_id][q_id] = temp_queue; + } +} + +void StreamingBarrierHelper::GetAllBarrier(std::vector &barrier_id_vec) { + std::transform( + global_barrier_map_.begin(), global_barrier_map_.end(), + std::back_inserter(barrier_id_vec), + [](std::unordered_map>::value_type + pair) { return pair.first; }); +} + +bool StreamingBarrierHelper::Contains(uint64_t barrier_id) { + return global_barrier_map_.find(barrier_id) != global_barrier_map_.end(); +} + +uint32_t StreamingBarrierHelper::GetBarrierMapSize() { + return global_barrier_map_.size(); +} + +void StreamingBarrierHelper::GetCurrentMaxCheckpointIdInQueue( + const ObjectID &q_id, uint64_t &checkpoint_id) const { + auto item = current_max_checkpoint_id_map_.find(q_id); + if (item != 
current_max_checkpoint_id_map_.end()) { + checkpoint_id = item->second; + } else { + checkpoint_id = 0; + } +} + +void StreamingBarrierHelper::SetCurrentMaxCheckpointIdInQueue( + const ObjectID &q_id, const uint64_t checkpoint_id) { + current_max_checkpoint_id_map_[q_id] = checkpoint_id; +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/reliability/barrier_helper.h b/streaming/src/reliability/barrier_helper.h new file mode 100644 index 000000000..1ceacc47a --- /dev/null +++ b/streaming/src/reliability/barrier_helper.h @@ -0,0 +1,65 @@ +#pragma once +#include +#include + +#include "common/status.h" +#include "ray/common/id.h" + +namespace ray { +namespace streaming { +class StreamingBarrierHelper { + using BarrierIdQueue = std::shared_ptr>; + + private: + // Global barrier map set (global barrier id -> (channel id -> msg id)) + std::unordered_map> + global_barrier_map_; + + // Message id map to barrier id of each queue(continuous barriers hold same last message + // id) + // message id -> (queue id -> list(barrier id)). + // Thread unsafe to assign value in user's thread but collect it in loopforward thread. + std::unordered_map> + global_reversed_barrier_map_; + + std::unordered_map barrier_checkpoint_map_; + + std::unordered_map max_message_id_map_; + + // We assume default max checkpoint is 0. 
+ std::unordered_map current_max_checkpoint_id_map_; + + std::mutex message_id_map_barrier_mutex_; + + std::mutex global_barrier_mutex_; + + std::mutex barrier_map_checkpoint_mutex_; + + public: + StreamingStatus GetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, + uint64_t &msg_id); + void SetMsgIdByBarrierId(const ObjectID &q_id, uint64_t barrier_id, uint64_t seq_id); + bool Contains(uint64_t barrier_id); + void ReleaseBarrierMapById(uint64_t barrier_id); + void ReleaseAllBarrierMap(); + void GetAllBarrier(std::vector &barrier_id_vec); + uint32_t GetBarrierMapSize(); + + void MapBarrierToCheckpoint(uint64_t barrier_id, uint64_t checkpoint); + StreamingStatus GetCheckpointIdByBarrierId(uint64_t barrier_id, + uint64_t &checkpoint_id); + void ReleaseBarrierMapCheckpointByBarrierId(const uint64_t barrier_id); + + StreamingStatus GetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id, + uint64_t &barrier_id, bool is_pop = false); + void SetBarrierIdByLastMessageId(const ObjectID &q_id, uint64_t message_id, + uint64_t barrier_id); + + void GetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id, + uint64_t &checkpoint_id) const; + + void SetCurrentMaxCheckpointIdInQueue(const ObjectID &q_id, + const uint64_t checkpoint_id); +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/reliability_helper.cc b/streaming/src/reliability_helper.cc new file mode 100644 index 000000000..9e4a083ab --- /dev/null +++ b/streaming/src/reliability_helper.cc @@ -0,0 +1,113 @@ +#include "reliability_helper.h" + +#include +namespace ray { +namespace streaming { + +std::shared_ptr ReliabilityHelperFactory::CreateReliabilityHelper( + const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader) { + if (config.IsExactlyOnce()) { + return std::make_shared(config, barrier_helper, writer, reader); + } else { + return std::make_shared(config, barrier_helper, writer, reader); + } +} + 
+ReliabilityHelper::ReliabilityHelper(const StreamingConfig &config, + StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader) + : config_(config), + barrier_helper_(barrier_helper), + writer_(writer), + reader_(reader) {} + +void ReliabilityHelper::Reload() {} + +bool ReliabilityHelper::StoreBundleMeta(ProducerChannelInfo &channel_info, + StreamingMessageBundlePtr &bundle_ptr, + bool is_replay) { + return false; +} + +bool ReliabilityHelper::FilterMessage(ProducerChannelInfo &channel_info, + const uint8_t *data, + StreamingMessageType message_type, + uint64_t *write_message_id) { + bool is_filtered = false; + uint64_t &message_id = channel_info.current_message_id; + uint64_t last_msg_id = channel_info.message_last_commit_id; + + if (StreamingMessageType::Barrier == message_type) { + is_filtered = message_id < last_msg_id; + } else { + message_id++; + // Message last commit id is the last item in queue or restore from queue. + // It skip directly since message id is less or equal than current commit id. 
+ is_filtered = message_id <= last_msg_id && !config_.IsAtLeastOnce(); + } + *write_message_id = message_id; + + return is_filtered; +} + +void ReliabilityHelper::CleanupCheckpoint(ProducerChannelInfo &channel_info, + uint64_t barrier_id) {} + +StreamingStatus ReliabilityHelper::InitChannelMerger(uint32_t timeout) { + return reader_->InitChannelMerger(timeout); +} + +StreamingStatus ReliabilityHelper::HandleNoValidItem(ConsumerChannelInfo &channel_info) { + STREAMING_LOG(DEBUG) << "[Reader] Queue " << channel_info.channel_id + << " get item timeout, resend notify " + << channel_info.current_message_id; + reader_->NotifyConsumedItem(channel_info, channel_info.current_message_id); + return StreamingStatus::OK; +} + +AtLeastOnceHelper::AtLeastOnceHelper(const StreamingConfig &config, + StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader) + : ReliabilityHelper(config, barrier_helper, writer, reader) {} + +StreamingStatus AtLeastOnceHelper::InitChannelMerger(uint32_t timeout) { + // No merge in AT_LEAST_ONCE + return StreamingStatus::OK; +} + +StreamingStatus AtLeastOnceHelper::HandleNoValidItem(ConsumerChannelInfo &channel_info) { + if (current_sys_time_ms() - channel_info.resend_notify_timer > + StreamingConfig::RESEND_NOTIFY_MAX_INTERVAL) { + STREAMING_LOG(INFO) << "[Reader] Queue " << channel_info.channel_id + << " get item timeout, resend notify " + << channel_info.current_message_id; + reader_->NotifyConsumedItem(channel_info, channel_info.current_message_id); + channel_info.resend_notify_timer = current_sys_time_ms(); + } + return StreamingStatus::Invalid; +} + +ExactlyOnceHelper::ExactlyOnceHelper(const StreamingConfig &config, + StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader) + : ReliabilityHelper(config, barrier_helper, writer, reader) {} + +bool ExactlyOnceHelper::FilterMessage(ProducerChannelInfo &channel_info, + const uint8_t *data, + StreamingMessageType message_type, + uint64_t 
*write_message_id) { + bool is_filtered = ReliabilityHelper::FilterMessage(channel_info, data, message_type, + write_message_id); + if (is_filtered && StreamingMessageType::Barrier == message_type && + StreamingRole::SOURCE == config_.GetStreamingRole()) { + *write_message_id = channel_info.message_last_commit_id; + // Do not skip source barrier when it's reconstructing from downstream. + is_filtered = false; + STREAMING_LOG(INFO) << "append barrier to buffer ring " << *write_message_id + << ", last commit id " << channel_info.message_last_commit_id; + } + return is_filtered; +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/reliability_helper.h b/streaming/src/reliability_helper.h new file mode 100644 index 000000000..56089a085 --- /dev/null +++ b/streaming/src/reliability_helper.h @@ -0,0 +1,66 @@ +#pragma once +#include "channel/channel.h" +#include "data_reader.h" +#include "data_writer.h" +#include "reliability/barrier_helper.h" +#include "util/config.h" + +namespace ray { +namespace streaming { + +class ReliabilityHelper; +class DataWriter; +class DataReader; + +class ReliabilityHelperFactory { + public: + static std::shared_ptr CreateReliabilityHelper( + const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader); +}; + +class ReliabilityHelper { + public: + ReliabilityHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader); + virtual ~ReliabilityHelper() = default; + // Only exactly same need override this function. + virtual void Reload(); + // Store bundle meta or skip in replay mode. + virtual bool StoreBundleMeta(ProducerChannelInfo &channel_info, + StreamingMessageBundlePtr &bundle_ptr, + bool is_replay = false); + virtual void CleanupCheckpoint(ProducerChannelInfo &channel_info, uint64_t barrier_id); + // Filter message by different failover strategies. 
+ virtual bool FilterMessage(ProducerChannelInfo &channel_info, const uint8_t *data, + StreamingMessageType message_type, + uint64_t *write_message_id); + virtual StreamingStatus InitChannelMerger(uint32_t timeout); + virtual StreamingStatus HandleNoValidItem(ConsumerChannelInfo &channel_info); + + protected: + const StreamingConfig &config_; + StreamingBarrierHelper &barrier_helper_; + DataWriter *writer_; + DataReader *reader_; +}; + +class AtLeastOnceHelper : public ReliabilityHelper { + public: + AtLeastOnceHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader); + StreamingStatus InitChannelMerger(uint32_t timeout) override; + StreamingStatus HandleNoValidItem(ConsumerChannelInfo &channel_info) override; +}; + +class ExactlyOnceHelper : public ReliabilityHelper { + public: + ExactlyOnceHelper(const StreamingConfig &config, StreamingBarrierHelper &barrier_helper, + DataWriter *writer, DataReader *reader); + bool FilterMessage(ProducerChannelInfo &channel_info, const uint8_t *data, + StreamingMessageType message_type, + uint64_t *write_message_id) override; + virtual ~ExactlyOnceHelper() = default; +}; +} // namespace streaming +} // namespace ray diff --git a/streaming/src/ring_buffer.cc b/streaming/src/ring_buffer/ring_buffer.cc similarity index 100% rename from streaming/src/ring_buffer.cc rename to streaming/src/ring_buffer/ring_buffer.cc diff --git a/streaming/src/ring_buffer.h b/streaming/src/ring_buffer/ring_buffer.h similarity index 100% rename from streaming/src/ring_buffer.h rename to streaming/src/ring_buffer/ring_buffer.h diff --git a/streaming/src/runtime_context.h b/streaming/src/runtime_context.h index 4b6f49ab8..a86ebbcd1 100644 --- a/streaming/src/runtime_context.h +++ b/streaming/src/runtime_context.h @@ -2,8 +2,8 @@ #include +#include "common/status.h" #include "config/streaming_config.h" -#include "status.h" namespace ray { namespace streaming { diff --git 
a/streaming/src/test/event_service_tests.cc b/streaming/src/test/event_service_tests.cc index 39e2793aa..5775cb247 100644 --- a/streaming/src/test/event_service_tests.cc +++ b/streaming/src/test/event_service_tests.cc @@ -23,7 +23,7 @@ TEST(EventServiceTest, Test1) { std::thread thread_empty([server, &mock_channel_info, &stop] { std::chrono::milliseconds MockTimer(20); while (!stop) { - Event event{&mock_channel_info, EventType::EmptyEvent, true}; + Event event(&mock_channel_info, EventType::EmptyEvent, true); server->Push(event); std::this_thread::sleep_for(MockTimer); } @@ -32,7 +32,7 @@ TEST(EventServiceTest, Test1) { std::thread thread_flow([server, &mock_channel_info, &stop] { std::chrono::milliseconds MockTimer(2); while (!stop) { - Event event{&mock_channel_info, EventType::FlowEvent, true}; + Event event(&mock_channel_info, EventType::FlowEvent, true); server->Push(event); std::this_thread::sleep_for(MockTimer); } @@ -41,7 +41,7 @@ TEST(EventServiceTest, Test1) { std::thread thread_user([server, &mock_channel_info, &stop] { std::chrono::milliseconds MockTimer(2); while (!stop) { - Event event{&mock_channel_info, EventType::UserEvent, true}; + Event event(&mock_channel_info, EventType::UserEvent, true); server->Push(event); std::this_thread::sleep_for(MockTimer); } @@ -76,9 +76,9 @@ TEST(EventServiceTest, remove_delete_channel_event) { mock_channel_info_vec.push_back(mock_channel_info2); for (auto &id : mock_channel_info_vec) { - Event empty_event{&id, EventType::EmptyEvent, true}; - Event user_event{&id, EventType::UserEvent, true}; - Event flow_event{&id, EventType::FlowEvent, true}; + Event empty_event(&id, EventType::EmptyEvent, true); + Event user_event(&id, EventType::UserEvent, true); + Event flow_event(&id, EventType::FlowEvent, true); server->Push(empty_event); server->Push(user_event); server->Push(flow_event); diff --git a/streaming/src/test/message_serialization_tests.cc b/streaming/src/test/message_serialization_tests.cc index 
b94064591..14dfc0232 100644 --- a/streaming/src/test/message_serialization_tests.cc +++ b/streaming/src/test/message_serialization_tests.cc @@ -80,7 +80,7 @@ TEST(StreamingSerializationTest, streaming_message_barrier_bundle_serialization_ auto s_item = s_message_list.back(); EXPECT_TRUE(s_item->ClassBytesSize() == m_item->ClassBytesSize()); EXPECT_TRUE(s_item->GetMessageType() == m_item->GetMessageType()); - EXPECT_TRUE(s_item->GetMessageSeqId() == m_item->GetMessageSeqId()); + EXPECT_TRUE(s_item->GetMessageId() == m_item->GetMessageId()); EXPECT_TRUE(s_item->GetDataSize() == m_item->GetDataSize()); EXPECT_TRUE( std::memcmp(s_item->RawData(), m_item->RawData(), m_item->GetDataSize()) == 0); diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 9bbcad803..4d3c652a8 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -67,27 +67,13 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { } private: - void TestWriteMessageToBufferRing(std::shared_ptr writer_client, - std::vector &q_list) { - // const uint8_t temp_data[] = {1, 2, 4, 5}; + void StreamingWriterExactlyOnceTest() { + StreamingConfig config; + StreamingWriterStrategyTest(config); - uint32_t i = 1; - while (i <= MESSAGE_BOUND_SIZE) { - for (auto &q_id : q_list) { - uint64_t buffer_len = (i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE); - uint8_t *data = new uint8_t[buffer_len]; - for (uint32_t j = 0; j < buffer_len; ++j) { - data[j] = j % 128; - } - - writer_client->WriteMessageToBufferRing(q_id, data, buffer_len, - StreamingMessageType::Message); - } - ++i; - } - - // Wait a while - std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + STREAMING_LOG(INFO) + << "StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest"; + status_ = true; } void StreamingWriterStrategyTest(StreamingConfig &config) { @@ -111,6 +97,7 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { std::shared_ptr 
runtime_context(new RuntimeContext()); runtime_context->SetConfig(config); + // Create writer. std::shared_ptr streaming_writer_client(new DataWriter(runtime_context)); uint64_t queue_size = 10 * 1000 * 1000; std::vector channel_seq_id_vec(queue_ids_.size(), 0); @@ -119,22 +106,35 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { STREAMING_LOG(INFO) << "streaming_writer_client Init done"; streaming_writer_client->Run(); + + // Write some data. std::thread test_loop_thread( &StreamingQueueWriterTestSuite::TestWriteMessageToBufferRing, this, streaming_writer_client, std::ref(queue_ids_)); - // test_loop_thread.detach(); if (test_loop_thread.joinable()) { test_loop_thread.join(); } } - void StreamingWriterExactlyOnceTest() { - StreamingConfig config; - StreamingWriterStrategyTest(config); + void TestWriteMessageToBufferRing(std::shared_ptr writer_client, + std::vector &q_list) { + uint32_t i = 1; + while (i <= MESSAGE_BOUND_SIZE) { + for (auto &q_id : q_list) { + uint64_t buffer_len = (i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE); + uint8_t *data = new uint8_t[buffer_len]; + for (uint32_t j = 0; j < buffer_len; ++j) { + data[j] = j % 128; + } - STREAMING_LOG(INFO) - << "StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest"; - status_ = true; + writer_client->WriteMessageToBufferRing(q_id, data, buffer_len, + StreamingMessageType::Message); + } + ++i; + } + STREAMING_LOG(INFO) << "Write data done."; + // Wait a while. 
+ std::this_thread::sleep_for(std::chrono::milliseconds(5000)); } }; @@ -180,7 +180,7 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { for (auto &q_id : queue_id_vec) { reader_client->NotifyConsumedItem((*offset_map)[q_id], - (*offset_map)[q_id].current_seq_id); + (*offset_map)[q_id].current_message_id); } // writer_client->ClearCheckpoint(msg->last_barrier_id); @@ -201,7 +201,7 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { recevied_message_cnt += message_list.size(); for (auto &item : message_list) { - uint64_t i = item->GetMessageSeqId(); + uint64_t i = item->GetMessageId(); uint32_t buff_len = i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE; if (i > MESSAGE_BOUND_SIZE) break; @@ -270,7 +270,7 @@ class StreamingQueueUpStreamTestSuite : public StreamingQueueTestSuite { } void GetQueueTest() { - // Sleep 2s, queue shoulde not exist when reader pull + // Sleep 2s, queue shoulde not exist when reader pull. std::this_thread::sleep_for(std::chrono::milliseconds(2000)); auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); ObjectID &queue_id = queue_ids_[0]; @@ -297,7 +297,7 @@ class StreamingQueueUpStreamTestSuite : public StreamingQueueTestSuite { } void PullPeerAsyncTest() { - // Sleep 2s, queue should not exist when reader pull + // Sleep 2s, queue should not exist when reader pull. 
std::this_thread::sleep_for(std::chrono::milliseconds(2000)); auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); ObjectID &queue_id = queue_ids_[0]; @@ -323,10 +323,8 @@ class StreamingQueueUpStreamTestSuite : public StreamingQueueTestSuite { uint8_t data[100]; memset(data, msg_id, 100); STREAMING_LOG(INFO) << "Writer User Push item msg_id: " << msg_id; - ASSERT_TRUE(queue - ->Push(msg_id /*seqid*/, data, 100, current_sys_time_ms(), msg_id, - msg_id, true) - .ok()); + ASSERT_TRUE( + queue->Push(data, 100, current_sys_time_ms(), msg_id, msg_id, true).ok()); queue->Send(); } diff --git a/streaming/src/test/mock_transfer_tests.cc b/streaming/src/test/mock_transfer_tests.cc index f6268ec81..62085ac73 100644 --- a/streaming/src/test/mock_transfer_tests.cc +++ b/streaming/src/test/mock_transfer_tests.cc @@ -10,7 +10,7 @@ TEST(StreamingMockTransfer, mock_produce_consume) { ObjectID channel_id = ObjectID::FromRandom(); ProducerChannelInfo producer_channel_info; producer_channel_info.channel_id = channel_id; - producer_channel_info.current_seq_id = 0; + producer_channel_info.current_message_id = 0; MockProducer producer(transfer_config, producer_channel_info); ConsumerChannelInfo consumer_channel_info; @@ -22,15 +22,12 @@ TEST(StreamingMockTransfer, mock_produce_consume) { producer.ProduceItemToChannel(data, 3); uint8_t *data_consumed; uint32_t data_size_consumed; - uint64_t data_seq_id; - consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1); + consumer.ConsumeItemFromChannel(data_consumed, data_size_consumed, -1); EXPECT_EQ(data_size_consumed, 3); - EXPECT_EQ(data_seq_id, 1); EXPECT_EQ(std::memcmp(data_consumed, data, 3), 0); consumer.NotifyChannelConsumed(1); - auto status = - consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1); + auto status = consumer.ConsumeItemFromChannel(data_consumed, data_size_consumed, -1); EXPECT_EQ(status, StreamingStatus::NoSuchItem); } @@ -52,8 +49,9 
@@ class StreamingTransferTest : public ::testing::Test { std::vector channel_id_vec(queue_vec.size(), 0); std::vector queue_size_vec(queue_vec.size(), 10000); std::vector params(queue_vec.size()); + std::vector creation_status; writer->Init(queue_vec, params, channel_id_vec, queue_size_vec); - reader->Init(queue_vec, params, channel_id_vec, queue_size_vec, -1); + reader->Init(queue_vec, params, channel_id_vec, creation_status, -1); } void DestroyTransfer() { writer.reset(); @@ -152,18 +150,21 @@ TEST_F(StreamingTransferTest, flow_control_test) { reader->GetOffsetInfo(reader_offset_info); uint32_t writer_step = writer_runtime_context->GetConfig().GetWriterConsumedStep(); uint32_t reader_step = reader_runtime_context->GetConfig().GetReaderConsumedStep(); - uint64_t &writer_current_seq_id = (*writer_offset_info)[queue_vec[0]].current_seq_id; - uint64_t &writer_current_message_id = + uint64_t &writer_current_msg_id = (*writer_offset_info)[queue_vec[0]].current_message_id; - uint64_t &reader_target_seq_id = - (*reader_offset_info)[queue_vec[0]].queue_info.target_seq_id; - while (writer_current_seq_id < writer_step) { - STREAMING_LOG(INFO) << "Writer currrent seq id " << writer_current_seq_id - << " message " << writer_current_message_id << " consumer step " - << writer_step; + uint64_t &writer_last_commit_id = + (*writer_offset_info)[queue_vec[0]].message_last_commit_id; + uint64_t &writer_target_msg_id = + (*writer_offset_info)[queue_vec[0]].queue_info.target_message_id; + uint64_t &reader_target_msg_id = + (*reader_offset_info)[queue_vec[0]].queue_info.target_message_id; + do { std::this_thread::sleep_for( std::chrono::milliseconds(StreamingConfig::TIME_WAIT_UINT)); - } + STREAMING_LOG(INFO) << "Writer currrent msg id " << writer_current_msg_id + << ", writer target_msg_id=" << writer_target_msg_id + << ", consumer step " << writer_step; + } while (writer_current_msg_id < writer_step); std::list read_message_list; while (read_message_list.size() < num) { @@ -173,8 
+174,8 @@ TEST_F(StreamingTransferTest, flow_control_test) { auto &message_list = bundle_ptr->GetMessageList(); std::copy(message_list.begin(), message_list.end(), std::back_inserter(read_message_list)); - ASSERT_GE(writer_step, writer_current_seq_id - msg->seq_id); - ASSERT_GE(msg->seq_id + reader_step, reader_target_seq_id); + ASSERT_GE(writer_step, writer_last_commit_id - msg->meta->GetLastMessageId()); + ASSERT_GE(msg->meta->GetLastMessageId() + reader_step, reader_target_msg_id); } int index = 0; for (auto &message : read_message_list) { diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index 25c4c0061..fc9e81d18 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -87,7 +87,11 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { args.emplace_back(new TaskArgByValue(std::make_shared( msg.ToBytes(), nullptr, std::vector(), true))); std::unordered_map resources; +<<<<<<< HEAD + TaskOptions options(0, resources); +======= TaskOptions options{"", 0, resources}; +>>>>>>> 6a78ba9752dc7f17b0e4b7423898c0facf777d3d std::vector return_ids; RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython("", "", "init", "")}; @@ -103,7 +107,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector(), true))); std::unordered_map resources; - TaskOptions options{"", 0, resources}; + TaskOptions options(0, resources); std::vector return_ids; RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( "", test, "execute_test", "")}; @@ -119,7 +123,11 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { args.emplace_back(new TaskArgByValue( std::make_shared(buffer, nullptr, std::vector(), true))); std::unordered_map resources; +<<<<<<< HEAD + TaskOptions options(1, resources); +======= TaskOptions options{"", 1, resources}; 
+>>>>>>> 6a78ba9752dc7f17b0e4b7423898c0facf777d3d std::vector return_ids; RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( "", "", "check_current_test_status", "")}; diff --git a/streaming/src/test/run_streaming_queue_test.sh b/streaming/src/test/run_streaming_queue_test.sh index c53e6295f..752c95b2b 100755 --- a/streaming/src/test/run_streaming_queue_test.sh +++ b/streaming/src/test/run_streaming_queue_test.sh @@ -44,15 +44,22 @@ if [ ! -d "$RAY_ROOT/python" ]; then exit 1 fi -REDIS_MODULE="./bazel-bin/libray_redis_module.so" -REDIS_SERVER_EXEC="./bazel-bin/external/com_github_antirez_redis/redis-server" -STORE_EXEC="./bazel-bin/plasma_store_server" -REDIS_CLIENT_EXEC="./bazel-bin/redis-cli" -RAYLET_EXEC="./bazel-bin/raylet" -STREAMING_TEST_WORKER_EXEC="./bazel-bin/streaming/streaming_test_worker" -GCS_SERVER_EXEC="./bazel-bin/gcs_server" +REDIS_MODULE="$RAY_ROOT/bazel-bin/libray_redis_module.so" +REDIS_SERVER_EXEC="$RAY_ROOT/bazel-bin/external/com_github_antirez_redis/redis-server" +STORE_EXEC="$RAY_ROOT/bazel-bin/plasma_store_server" +REDIS_CLIENT_EXEC="$RAY_ROOT/bazel-bin/redis-cli" +RAYLET_EXEC="$RAY_ROOT/bazel-bin/raylet" +STREAMING_TEST_WORKER_EXEC="$RAY_ROOT/bazel-bin/streaming/streaming_test_worker" +GCS_SERVER_EXEC="$RAY_ROOT/bazel-bin/gcs_server" + +# clear env +pgrep "plasma|DefaultDriver|DefaultWorker|AppStarter|redis|http_server|job_agent" | xargs kill -9 &> /dev/null -# Allow cleanup commands to fail. # Run tests. 
-./bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC "$RAYLET_PORT" $STREAMING_TEST_WORKER_EXEC $GCS_SERVER_EXEC $REDIS_SERVER_EXEC $REDIS_MODULE $REDIS_CLIENT_EXEC + +# to run specific test, add --gtest_filter, below is an example +#$RAY_ROOT/bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $STREAMING_TEST_WORKER_EXEC $GCS_SERVER_EXEC $REDIS_SERVER_EXEC $REDIS_MODULE $REDIS_CLIENT_EXEC --gtest_filter=StreamingTest/StreamingWriterTest.streaming_writer_exactly_once_test/0 + +# run all tests +"$RAY_ROOT"/bazel-bin/streaming/streaming_queue_tests "$STORE_EXEC" "$RAYLET_EXEC" "$RAYLET_PORT" "$STREAMING_TEST_WORKER_EXEC" "$GCS_SERVER_EXEC" "$REDIS_SERVER_EXEC" "$REDIS_MODULE" "$REDIS_CLIENT_EXEC" sleep 1s diff --git a/streaming/src/test/streaming_queue_tests.cc b/streaming/src/test/streaming_queue_tests.cc index c2c678315..f45e2a45c 100644 --- a/streaming/src/test/streaming_queue_tests.cc +++ b/streaming/src/test/streaming_queue_tests.cc @@ -66,7 +66,6 @@ INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingExactlySameTest, } // namespace ray int main(int argc, char **argv) { - // set_streaming_log_config("streaming_writer_test", StreamingLogLevel::INFO, 0); ::testing::InitGoogleTest(&argc, argv); RAY_CHECK(argc == 9); ray::TEST_STORE_EXEC_PATH = std::string(argv[1]); diff --git a/streaming/src/util/config.cc b/streaming/src/util/config.cc new file mode 100644 index 000000000..b5c3e320a --- /dev/null +++ b/streaming/src/util/config.cc @@ -0,0 +1,20 @@ +#include "config.h" +namespace ray { +namespace streaming { + +boost::any &Config::Get(ConfigEnum key) const { + auto item = config_map_.find(key); + STREAMING_CHECK(item != config_map_.end()); + return item->second; +} + +boost::any Config::Get(ConfigEnum key, boost::any default_value) const { + auto item = config_map_.find(key); + if (item == config_map_.end()) { + return default_value; + } + return item->second; +} + +} // namespace streaming +} // namespace ray diff --git 
a/streaming/src/util/config.h b/streaming/src/util/config.h new file mode 100644 index 000000000..56c6af81f --- /dev/null +++ b/streaming/src/util/config.h @@ -0,0 +1,80 @@ +#pragma once +#include +#include + +#include "streaming_logging.h" + +namespace ray { +namespace streaming { +enum class ConfigEnum : uint32_t { + QUEUE_ID_VECTOR = 0, + MIN = QUEUE_ID_VECTOR, + MAX = QUEUE_ID_VECTOR +}; +} +} // namespace ray + +namespace std { +template <> +struct hash<::ray::streaming::ConfigEnum> { + size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { + return static_cast(config_enum_key); + } +}; + +template <> +struct hash { + size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { + return static_cast(config_enum_key); + } +}; +} // namespace std + +namespace ray { +namespace streaming { + +class Config { + public: + template + inline void Set(ConfigEnum key, const ValueType &any) { + config_map_.emplace(key, any); + } + + template + inline void Set(ConfigEnum key, ValueType &&any) { + config_map_.emplace(key, any); + } + + template + inline boost::any &GetOrDefault(ConfigEnum key, ValueType &&any) { + auto item = config_map_.find(key); + if (item != config_map_.end()) { + return item->second; + } + Set(key, any); + return any; + } + + boost::any &Get(ConfigEnum key) const; + boost::any Get(ConfigEnum key, boost::any default_value) const; + + inline uint32_t GetInt32(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline uint64_t GetInt64(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline double GetDouble(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline bool GetBool(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline std::string GetString(ConfigEnum key) { + return boost::any_cast(Get(key)); + } + + virtual ~Config() = default; + + protected: + mutable std::unordered_map config_map_; +}; + +} // namespace streaming +} // namespace ray diff --git 
a/streaming/src/util/streaming_util.cc b/streaming/src/util/streaming_util.cc index 95038f2a9..4f2a13535 100644 --- a/streaming/src/util/streaming_util.cc +++ b/streaming/src/util/streaming_util.cc @@ -3,21 +3,6 @@ #include namespace ray { namespace streaming { - -boost::any &Config::Get(ConfigEnum key) const { - auto item = config_map_.find(key); - STREAMING_CHECK(item != config_map_.end()); - return item->second; -} - -boost::any Config::Get(ConfigEnum key, boost::any default_value) const { - auto item = config_map_.find(key); - if (item == config_map_.end()) { - return default_value; - } - return item->second; -} - std::string Util::Byte2hex(const uint8_t *data, uint32_t data_size) { constexpr char hex[] = "0123456789abcdef"; std::string result; diff --git a/streaming/src/util/streaming_util.h b/streaming/src/util/streaming_util.h index 8f28dc3ec..bd7035354 100644 --- a/streaming/src/util/streaming_util.h +++ b/streaming/src/util/streaming_util.h @@ -1,97 +1,84 @@ #pragma once #include +#include #include #include +#include "ray/common/id.h" #include "util/streaming_logging.h" namespace ray { namespace streaming { -enum class ConfigEnum : uint32_t { - QUEUE_ID_VECTOR = 0, - RECONSTRUCT_RETRY_TIMES, - RECONSTRUCT_TIMEOUT_PER_MB, - CURRENT_DRIVER_ID, - /// For direct call - CORE_WORKER, - SYNC_FUNCTION, - ASYNC_FUNCTION, - TRANSFER_MIN = QUEUE_ID_VECTOR, - TRANSFER_MAX = ASYNC_FUNCTION -}; -} // namespace streaming -} // namespace ray - -namespace std { -template <> -struct hash<::ray::streaming::ConfigEnum> { - size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { - return static_cast(config_enum_key); - } -}; - -template <> -struct hash { - size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { - return static_cast(config_enum_key); - } -}; -} // namespace std - -namespace ray { -namespace streaming { - -class Config { - public: - template - inline void Set(ConfigEnum key, const ValueType &any) { - 
config_map_.emplace(key, any); - } - - template - inline void Set(ConfigEnum key, ValueType &&any) { - config_map_.emplace(key, any); - } - - template - inline boost::any &GetOrDefault(ConfigEnum key, ValueType &&any) { - auto item = config_map_.find(key); - if (item != config_map_.end()) { - return item->second; - } - Set(key, any); - return any; - } - - boost::any &Get(ConfigEnum key) const; - - boost::any Get(ConfigEnum key, boost::any default_value) const; - - inline uint32_t GetInt32(ConfigEnum key) { return boost::any_cast(Get(key)); } - - inline uint64_t GetInt64(ConfigEnum key) { return boost::any_cast(Get(key)); } - - inline double GetDouble(ConfigEnum key) { return boost::any_cast(Get(key)); } - - inline bool GetBool(ConfigEnum key) { return boost::any_cast(Get(key)); } - - inline std::string GetString(ConfigEnum key) { - return boost::any_cast(Get(key)); - } - - virtual ~Config() = default; - - protected: - mutable std::unordered_map config_map_; -}; - class Util { public: static std::string Byte2hex(const uint8_t *data, uint32_t data_size); static std::string Hexqid2str(const std::string &q_id_hex); + + template + static std::string join(const T &v, const std::string &delimiter, + const std::string &prefix = "", + const std::string &suffix = "") { + std::stringstream ss; + size_t i = 0; + ss << prefix; + for (const auto &elem : v) { + if (i != 0) { + ss << delimiter; + } + ss << elem; + i++; + } + ss << suffix; + return ss.str(); + } + + template + static std::string join(InputIterator first, InputIterator last, + const std::string &delim, const std::string &arround = "") { + std::string a = arround; + while (first != last) { + a += std::to_string(*first); + first++; + if (first != last) a += delim; + } + a += arround; + return a; + } + + template + static std::string join(InputIterator first, InputIterator last, + std::function func, + const std::string &delim, const std::string &arround = "") { + std::string a = arround; + while (first != last) { + a 
+= func(first); + first++; + if (first != last) a += delim; + } + a += arround; + return a; + } }; + +class AutoSpinLock { + public: + explicit AutoSpinLock(std::atomic_flag &lock) : lock_(lock) { + while (lock_.test_and_set(std::memory_order_acquire)) + ; + } + ~AutoSpinLock() { unlock(); } + void unlock() { lock_.clear(std::memory_order_release); } + + private: + std::atomic_flag &lock_; +}; + +inline void ConvertToValidQueueId(const ObjectID &queue_id) { + auto addr = const_cast(&queue_id); + *(reinterpret_cast(addr)) = 0; +} } // namespace streaming } // namespace ray