From f9f2bfa77861539e467802185d665ae79f5ce25c Mon Sep 17 00:00:00 2001 From: Lingxuan Zuo Date: Mon, 25 Jan 2021 20:32:08 +0800 Subject: [PATCH] [Metric] Fix crashed when register metric view in multithread (#13485) * Fix crashed when register metric view in multithread * fix comments * fix --- src/ray/stats/metric.cc | 27 +++++++++++++++++---------- src/ray/stats/metric.h | 3 +++ src/ray/stats/stats_test.cc | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/src/ray/stats/metric.cc b/src/ray/stats/metric.cc index 4a475a338..d4b253428 100644 --- a/src/ray/stats/metric.cc +++ b/src/ray/stats/metric.cc @@ -22,6 +22,8 @@ namespace ray { namespace stats { +absl::Mutex Metric::registration_mutex_; + static void RegisterAsView(opencensus::stats::ViewDescriptor view_descriptor, const std::vector<opencensus::tags::TagKey> &keys) { // Register global keys. @@ -85,19 +87,24 @@ void Metric::Record(double value, const TagsType &tags) { return; } + // NOTE(lingxuan.zlx): Double check for recording performance while + // processing in multithread and avoid race since metrics may invoke + // record in different threads or code pathes. if (measure_ == nullptr) { - // Measure could be registered before, so we try to get it first. - MeasureDouble registered_measure = - opencensus::stats::MeasureRegistry::GetMeasureDoubleByName(name_); + absl::MutexLock lock(&registration_mutex_); + if (measure_ == nullptr) { + // Measure could be registered before, so we try to get it first. 
+ MeasureDouble registered_measure = + opencensus::stats::MeasureRegistry::GetMeasureDoubleByName(name_); - if (registered_measure.IsValid()) { - measure_.reset(new MeasureDouble(registered_measure)); - } else { - measure_.reset( - new MeasureDouble(MeasureDouble::Register(name_, description_, unit_))); + if (registered_measure.IsValid()) { + measure_.reset(new MeasureDouble(registered_measure)); + } else { + measure_.reset( + new MeasureDouble(MeasureDouble::Register(name_, description_, unit_))); + } + RegisterView(); } - - RegisterView(); } // Do record. diff --git a/src/ray/stats/metric.h b/src/ray/stats/metric.h index 06e8534c4..dac50bc2d 100644 --- a/src/ray/stats/metric.h +++ b/src/ray/stats/metric.h @@ -129,6 +129,9 @@ class Metric { std::vector<opencensus::tags::TagKey> tag_keys_; std::unique_ptr<opencensus::stats::Measure<double>> measure_; + // For making sure thread-safe to all of metric registrations. + static absl::Mutex registration_mutex_; + }; // class Metric class Gauge : public Metric { diff --git a/src/ray/stats/stats_test.cc b/src/ray/stats/stats_test.cc index 21e162723..38f795282 100644 --- a/src/ray/stats/stats_test.cc +++ b/src/ray/stats/stats_test.cc @@ -116,6 +116,38 @@ TEST_F(StatsTest, InitializationTest) { ASSERT_TRUE(new_first_tag.second == test_tag_value_that_shouldnt_be_applied); } +TEST(Metric, MultiThreadMetricRegisterViewTest) { + ray::stats::Shutdown(); + std::shared_ptr<stats::MetricExporterClient> exporter( + new stats::StdoutExporterClient()); + ray::stats::Init({}, MetricsAgentPort, exporter); + std::vector<std::thread> threads; + const stats::TagKeyType tag1 = stats::TagKeyType::Register("k1"); + const stats::TagKeyType tag2 = stats::TagKeyType::Register("k2"); + for (int index = 0; index < 10; ++index) { + threads.emplace_back([tag1, tag2, index]() { + for (int i = 0; i < 100; i++) { + stats::Count random_counter( + "ray.random.counter" + std::to_string(index) + std::to_string(i), "", "", + {tag1, tag2}); + random_counter.Record(i); + stats::Gauge random_gauge( + "ray.random.gauge" + std::to_string(index) + 
std::to_string(i), "", "", + {tag1, tag2}); + random_gauge.Record(i); + stats::Sum random_sum( + "ray.random.sum" + std::to_string(index) + std::to_string(i), "", "", + {tag1, tag2}); + random_sum.Record(i); + } + }); + } + for (auto &thread : threads) { + thread.join(); + } + ray::stats::Shutdown(); +} + TEST_F(StatsTest, MultiThreadedInitializationTest) { // Make sure stats module is thread-safe. // Shutdown the stats module first.