diff --git a/.travis.yml b/.travis.yml index 9cc64aabd..a58a78acd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -149,10 +149,10 @@ matrix: # - python -m pytest python/ray/dataframe/test/test_groupby.py # ray tune tests - # - python python/ray/tune/test/dependency_test.py - # - python -m pytest python/ray/tune/test/trial_runner_test.py + - python python/ray/tune/test/dependency_test.py + - python -m pytest python/ray/tune/test/trial_runner_test.py - python -m pytest python/ray/tune/test/trial_scheduler_test.py - # - python -m pytest python/ray/tune/test/tune_server_test.py + - python -m pytest python/ray/tune/test/tune_server_test.py # ray rllib tests - python -m pytest python/ray/rllib/test/test_catalog.py diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 3f07e7bb5..b7882c5f9 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -368,11 +368,17 @@ class TrialRunner(object): def _update_avail_resources(self): clients = ray.global_state.client_table() - local_schedulers = [ - entry for client in clients.values() for entry in client if - (entry['ClientType'] == 'local_scheduler' and not entry['Deleted']) - ] - num_cpus = sum(ls['CPU'] for ls in local_schedulers) - num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers) + if ray.worker.global_worker.use_raylet: + # TODO(rliaw): Remove once raylet flag is swapped + num_cpus = sum(cl['Resources']['CPU'] for cl in clients) + num_gpus = sum(cl['Resources'].get('GPU', 0) for cl in clients) + else: + local_schedulers = [ + entry for client in clients.values() for entry in client + if (entry['ClientType'] == 'local_scheduler' + and not entry['Deleted']) + ] + num_cpus = sum(ls['CPU'] for ls in local_schedulers) + num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers) self._avail_resources = Resources(int(num_cpus), int(num_gpus)) self._resources_initialized = True diff --git a/python/ray/worker.py b/python/ray/worker.py index 84e88e18d..266d3e48e 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1862,10 +1862,10 @@ def print_error_messages_raylet(worker): try: for msg in worker.error_message_pubsub_client.listen(): - gcs_entry = state.GcsTableEntry.GetRootAsGcsTableEntry( + gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( msg["data"], 0) assert gcs_entry.EntriesLength() == 1 - error_data = state.ErrorTableData.GetRootAsErrorTableData( + error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( gcs_entry.Entries(0), 0) NIL_JOB_ID = 20 * b"\x00" job_id = error_data.JobId()