From 4bb5b6bd5b98e29c462475a2b671b1de7e889f4c Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 24 Dec 2017 12:25:13 -0800 Subject: [PATCH] [rllib] A3C Configurations (#1370) * initial introduction of a3c configs * fix sample batch * flake but need to check save * save,resotre * fix * pickles * entropy * fix * moving ppo * results * jenkins --- python/ray/rllib/a3c/a3c.py | 123 +++++++++--------- .../a3c/{runner.py => base_evaluator.py} | 49 ++++--- python/ray/rllib/a3c/shared_model.py | 7 +- python/ray/rllib/a3c/shared_model_lstm.py | 7 +- python/ray/rllib/a3c/shared_torch_policy.py | 17 ++- python/ray/rllib/a3c/tfpolicy.py | 13 +- python/ray/rllib/a3c/torchpolicy.py | 4 +- python/ray/rllib/optimizers/async.py | 3 +- python/ray/rllib/ppo/ppo.py | 9 +- python/ray/rllib/ppo/rollout.py | 4 +- python/ray/rllib/ppo/runner.py | 4 +- .../tuned_examples/pong-a3c-pytorch.yaml | 7 +- python/ray/rllib/tuned_examples/pong-a3c.yaml | 10 +- python/ray/rllib/utils/process_rollout.py | 15 ++- python/ray/rllib/utils/sampler.py | 5 +- 15 files changed, 164 insertions(+), 113 deletions(-) rename python/ray/rllib/a3c/{runner.py => base_evaluator.py} (64%) diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index 2e2f5d840..3b73f7654 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -8,85 +8,83 @@ import os import ray from ray.rllib.agent import Agent -from ray.rllib.envs import create_and_wrap -from ray.rllib.a3c.runner import RemoteA3CEvaluator -from ray.rllib.a3c.common import get_policy_cls -from ray.rllib.utils.filter import get_filter +from ray.rllib.optimizers import AsyncOptimizer +from ray.rllib.a3c.base_evaluator import A3CEvaluator, RemoteA3CEvaluator from ray.tune.result import TrainingResult DEFAULT_CONFIG = { + # Number of workers (excluding master) "num_workers": 4, - "num_batches_per_iteration": 100, - # Size of rollout batch "batch_size": 10, - "use_lstm": True, + # Use LSTM model - only applicable for image states + "use_lstm": False, + # Use PyTorch as backend - no LSTM support "use_pytorch": False, # Which observation filter to apply to the observation "observation_filter": "NoFilter", # Which reward filter to apply to the reward "reward_filter": "NoFilter", - - "model": {"grayscale": True, - "zero_mean": False, - "dim": 42, - "channel_major": False} + # Discount factor of MDP + "gamma": 0.99, + # GAE(gamma) parameter + "lambda": 1.0, + # Max global norm for each gradient calculated by worker + "grad_clip": 40.0, + # Learning rate + "lr": 0.0001, + # Value Function Loss coefficient + "vf_loss_coeff": 0.5, + # Entropy coefficient + "entropy_coeff": -0.01, + # Preprocessing for environment + "preprocessing": { + # (Image statespace) - Converts image to Channels = 1 + "grayscale": True, + # (Image statespace) - Each pixel + "zero_mean": False, + # (Image statespace) - Converts image to (dim, dim, C) + "dim": 80, + # (Image statespace) - Converts image shape to (C, dim, dim) + "channel_major": False + }, + # Configuration for model specification + "model": {}, + # Arguments to pass to the rllib optimizer + "optimizer": { + # Number of gradients applied for each `train` step + "grads_per_step": 100, + }, } class A3CAgent(Agent): _agent_name = "A3C" _default_config = DEFAULT_CONFIG + _allow_unknown_subkeys = ["model", "optimizer"] def _init(self): - self.env = create_and_wrap(self.env_creator, self.config["model"]) - policy_cls = get_policy_cls(self.config) - self.policy = policy_cls( - self.env.observation_space.shape, self.env.action_space) - self.obs_filter = get_filter( - self.config["observation_filter"], - self.env.observation_space.shape) - self.rew_filter = get_filter(self.config["reward_filter"], ()) - self.agents = [ + self.local_evaluator = A3CEvaluator( + self.env_creator, self.config, self.logdir, start_sampler=False) + self.remote_evaluators = [ RemoteA3CEvaluator.remote( self.env_creator, self.config, self.logdir) for i in range(self.config["num_workers"])] - self.parameters = self.policy.get_weights() + self.optimizer = AsyncOptimizer( + self.config["optimizer"], self.local_evaluator, + self.remote_evaluators) def _train(self): - remote_params = ray.put(self.parameters) - ray.get([agent.set_weights.remote(remote_params) - for agent in self.agents]) - - gradient_list = {agent.compute_gradient.remote(): agent - for agent in self.agents} - max_batches = self.config["num_batches_per_iteration"] - batches_so_far = len(gradient_list) - while gradient_list: - [done_id], _ = ray.wait(list(gradient_list)) - gradient, info = ray.get(done_id) - agent = gradient_list.pop(done_id) - self.obs_filter.update(info["obs_filter"]) - self.rew_filter.update(info["rew_filter"]) - self.policy.apply_gradients(gradient) - self.parameters = self.policy.get_weights() - - if batches_so_far < max_batches: - batches_so_far += 1 - agent.update_filters.remote( - obs_filter=self.obs_filter, - rew_filter=self.rew_filter) - agent.set_weights.remote(self.parameters) - gradient_list[agent.compute_gradient.remote()] = agent - res = self._fetch_metrics_from_workers() + self.optimizer.step() + res = self._fetch_metrics_from_remote_evaluators() return res - def _fetch_metrics_from_workers(self): + def _fetch_metrics_from_remote_evaluators(self): episode_rewards = [] episode_lengths = [] - metric_lists = [ - a.get_completed_rollout_metrics.remote() for a in self.agents] + metric_lists = [a.get_completed_rollout_metrics.remote() + for a in self.remote_evaluators] for metrics in metric_lists: for episode in ray.get(metrics): episode_lengths.append(episode.episode_length) @@ -106,22 +104,25 @@ class A3CAgent(Agent): return result def _save(self): - # TODO(rliaw): extend to also support saving worker state? checkpoint_path = os.path.join( self.logdir, "checkpoint-{}".format(self.iteration)) - objects = [self.parameters, self.obs_filter, self.rew_filter] - pickle.dump(objects, open(checkpoint_path, "wb")) + # self.saver.save + agent_state = ray.get( + [a.save.remote() for a in self.remote_evaluators]) + extra_data = { + "remote_state": agent_state, + "local_state": self.local_evaluator.save()} + pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) return checkpoint_path def _restore(self, checkpoint_path): - objects = pickle.load(open(checkpoint_path, "rb")) - self.parameters = objects[0] - self.obs_filter = objects[1] - self.rew_filter = objects[2] - self.policy.set_weights(self.parameters) + extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) + ray.get( + [a.restore.remote(o) for a, o in zip( + self.remote_evaluators, extra_data["remote_state"])]) + self.local_evaluator.restore(extra_data["local_state"]) - # TODO(rliaw): augment to support LSTM def compute_action(self, observation): - obs = self.obs_filter(observation, update=False) - action, info = self.policy.compute(obs) + obs = self.local_evaluator.obs_filter(observation, update=False) + action, info = self.local_evaluator.policy.compute(obs) return action diff --git a/python/ray/rllib/a3c/runner.py b/python/ray/rllib/a3c/base_evaluator.py similarity index 64% rename from python/ray/rllib/a3c/runner.py rename to python/ray/rllib/a3c/base_evaluator.py index a380be70a..4872fec51 100644 --- a/python/ray/rllib/a3c/runner.py +++ b/python/ray/rllib/a3c/base_evaluator.py @@ -2,6 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import pickle + import ray from ray.rllib.envs import create_and_wrap from ray.rllib.optimizers import Evaluator @@ -23,16 +25,22 @@ class A3CEvaluator(Evaluator): rollouts. logdir: Directory for logging. """ - def __init__(self, env_creator, config, logdir): - self.env = env = create_and_wrap(env_creator, config["model"]) + def __init__(self, env_creator, config, logdir, start_sampler=True): + self.env = env = create_and_wrap(env_creator, config["preprocessing"]) policy_cls = get_policy_cls(config) # TODO(rliaw): should change this to be just env.observation_space - self.policy = policy_cls(env.observation_space.shape, env.action_space) - obs_filter = get_filter( + self.policy = policy_cls( + env.observation_space.shape, env.action_space, config) + self.config = config + + # Technically not needed when not remote + self.obs_filter = get_filter( config["observation_filter"], env.observation_space.shape) self.rew_filter = get_filter(config["reward_filter"], ()) - self.sampler = AsyncSampler(env, self.policy, obs_filter, + self.sampler = AsyncSampler(env, self.policy, self.obs_filter, config["batch_size"]) + if start_sampler and self.sampler.async: + self.sampler.start() self.logdir = logdir def sample(self): @@ -40,7 +48,10 @@ class A3CEvaluator(Evaluator): Returns: trajectory (PartialRollout): Experience Samples from evaluator""" rollout = self.sampler.get_data() - return rollout + samples = process_rollout( + rollout, self.rew_filter, gamma=self.config["gamma"], + lambda_=self.config["lambda"], use_gae=True) + return samples def get_completed_rollout_metrics(self): """Returns metrics on previously completed rollouts. @@ -49,20 +60,16 @@ class A3CEvaluator(Evaluator): """ return self.sampler.get_metrics() - def compute_gradient(self): - rollout = self.sampler.get_data() - obs_filter = self.sampler.get_obs_filter(flush=True) + def compute_gradients(self, samples): + gradient, info = self.policy.compute_gradients(samples) + return gradient - traj = process_rollout( - rollout, self.rew_filter, gamma=0.99, lambda_=1.0, use_gae=True) - gradient, info = self.policy.compute_gradients(traj) - info["obs_filter"] = obs_filter - info["rew_filter"] = self.rew_filter - return gradient, info - - def apply_gradient(self, grads): + def apply_gradients(self, grads): self.policy.apply_gradients(grads) + def get_weights(self): + return self.policy.get_weights() + def set_weights(self, params): self.policy.set_weights(params) @@ -73,5 +80,13 @@ class A3CEvaluator(Evaluator): if obs_filter: self.sampler.update_obs_filter(obs_filter) + def save(self): + weights = self.get_weights() + return pickle.dumps({"weights": weights}) + + def restore(self, objs): + objs = pickle.loads(objs) + self.set_weights(objs["weights"]) + RemoteA3CEvaluator = ray.remote(A3CEvaluator) diff --git a/python/ray/rllib/a3c/shared_model.py b/python/ray/rllib/a3c/shared_model.py index 75719215b..bbc70a7aa 100644 --- a/python/ray/rllib/a3c/shared_model.py +++ b/python/ray/rllib/a3c/shared_model.py @@ -13,13 +13,14 @@ class SharedModel(TFPolicy): other_output = ["vf_preds"] is_recurrent = False - def __init__(self, ob_space, ac_space, **kwargs): - super(SharedModel, self).__init__(ob_space, ac_space, **kwargs) + def __init__(self, ob_space, ac_space, config, **kwargs): + super(SharedModel, self).__init__(ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space) - self._model = ModelCatalog.get_model(self.x, self.logit_dim) + self._model = ModelCatalog.get_model( + self.x, self.logit_dim, self.config["model"]) self.logits = self._model.outputs self.curr_dist = dist_class(self.logits) # with tf.variable_scope("vf"): diff --git a/python/ray/rllib/a3c/shared_model_lstm.py b/python/ray/rllib/a3c/shared_model_lstm.py index 0abe0853f..dc99eac61 100644 --- a/python/ray/rllib/a3c/shared_model_lstm.py +++ b/python/ray/rllib/a3c/shared_model_lstm.py @@ -13,7 +13,7 @@ class SharedModelLSTM(TFPolicy): """ Attributes: other_output (list): Other than `action`, the other return values from - `compute_gradient`. + `compute_gradients`. is_recurrent (bool): True if is a recurrent network (requires features to be tracked). """ @@ -21,8 +21,9 @@ class SharedModelLSTM(TFPolicy): other_output = ["vf_preds", "features"] is_recurrent = True - def __init__(self, ob_space, ac_space, **kwargs): - super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs) + def __init__(self, ob_space, ac_space, config, **kwargs): + super(SharedModelLSTM, self).__init__( + ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) diff --git a/python/ray/rllib/a3c/shared_torch_policy.py b/python/ray/rllib/a3c/shared_torch_policy.py index 35df34837..83dce1a60 100644 --- a/python/ray/rllib/a3c/shared_torch_policy.py +++ b/python/ray/rllib/a3c/shared_torch_policy.py @@ -17,14 +17,16 @@ class SharedTorchPolicy(TorchPolicy): other_output = ["vf_preds"] is_recurrent = False - def __init__(self, ob_space, ac_space, **kwargs): + def __init__(self, ob_space, ac_space, config, **kwargs): super(SharedTorchPolicy, self).__init__( - ob_space, ac_space, **kwargs) + ob_space, ac_space, config, **kwargs) def _setup_graph(self, ob_space, ac_space): _, self.logit_dim = ModelCatalog.get_action_dist(ac_space) - self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim) - self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001) + self._model = ModelCatalog.get_torch_model( + ob_space, self.logit_dim, self.config["model"]) + self.optimizer = torch.optim.Adam( + self._model.parameters(), lr=self.config["lr"]) def compute(self, ob, *args): """Should take in a SINGLE ob""" @@ -68,6 +70,9 @@ class SharedTorchPolicy(TorchPolicy): value_err = 0.5 * (values - rs).pow(2).sum() self.optimizer.zero_grad() - overall_err = 0.5 * value_err + pi_err - entropy * 0.01 + overall_err = (pi_err + + value_err * self.config["vf_loss_coeff"] + + entropy * self.config["entropy_coeff"]) overall_err.backward() - torch.nn.utils.clip_grad_norm(self._model.parameters(), 40) + torch.nn.utils.clip_grad_norm( + self._model.parameters(), self.config["grad_clip"]) diff --git a/python/ray/rllib/a3c/tfpolicy.py b/python/ray/rllib/a3c/tfpolicy.py index e5624a755..d4f089986 100644 --- a/python/ray/rllib/a3c/tfpolicy.py +++ b/python/ray/rllib/a3c/tfpolicy.py @@ -10,8 +10,10 @@ from ray.rllib.a3c.policy import Policy class TFPolicy(Policy): """The policy base class.""" - def __init__(self, ob_space, action_space, name="local", summarize=True): + def __init__(self, ob_space, action_space, config, + name="local", summarize=True): self.local_steps = 0 + self.config = config self.summarize = summarize worker_device = "/job:localhost/replica:0/task:0/cpu:0" self.g = tf.Graph() @@ -52,13 +54,15 @@ class TFPolicy(Policy): delta = self.vf - self.r self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) self.entropy = tf.reduce_sum(self.curr_dist.entropy()) - self.loss = self.pi_loss + 0.5 * self.vf_loss - self.entropy * 0.01 + self.loss = (self.pi_loss + + self.vf_loss * self.config["vf_loss_coeff"] + + self.entropy * self.config["entropy_coeff"]) def setup_gradients(self): grads = tf.gradients(self.loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, 40.0) + self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) grads_and_vars = list(zip(self.grads, self.var_list)) - opt = tf.train.AdamOptimizer(1e-4) + opt = tf.train.AdamOptimizer(self.config["lr"]) self._apply_gradients = opt.apply_gradients(grads_and_vars) def initialize(self): @@ -71,6 +75,7 @@ class TFPolicy(Policy): tf.summary.scalar("model/var_gnorm", tf.global_norm(self.var_list)) self.summary_op = tf.summary.merge_all() + # TODO(rliaw): Can consider exposing these parameters self.sess = tf.Session(graph=self.g, config=tf.ConfigProto( intra_op_parallelism_threads=1, inter_op_parallelism_threads=2)) self.variables = ray.experimental.TensorFlowVariables(self.loss, diff --git a/python/ray/rllib/a3c/torchpolicy.py b/python/ray/rllib/a3c/torchpolicy.py index 4704d2365..b473c41f8 100644 --- a/python/ray/rllib/a3c/torchpolicy.py +++ b/python/ray/rllib/a3c/torchpolicy.py @@ -15,8 +15,10 @@ class TorchPolicy(Policy): The model is a separate object than the policy. This could be changed in the future.""" - def __init__(self, ob_space, action_space, name="local", summarize=True): + def __init__(self, ob_space, action_space, config, + name="local", summarize=True): self.local_steps = 0 + self.config = config self.summarize = summarize self._setup_graph(ob_space, action_space) torch.set_num_threads(2) diff --git a/python/ray/rllib/optimizers/async.py b/python/ray/rllib/optimizers/async.py index f4033dc25..b48fed3f8 100644 --- a/python/ray/rllib/optimizers/async.py +++ b/python/ray/rllib/optimizers/async.py @@ -35,8 +35,7 @@ class AsyncOptimizer(Optimizer): # Note: can't use wait: https://github.com/ray-project/ray/issues/1128 while gradient_queue: with self.wait_timer: - fut, e = gradient_queue[0] - gradient_queue = gradient_queue[1:] + fut, e = gradient_queue.pop(0) gradient = ray.get(fut) if gradient is not None: diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index 80091ba7d..6f33a50dd 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -78,7 +78,10 @@ DEFAULT_CONFIG = { # is detected "tf_debug_inf_or_nan": False, # If True, we write tensorflow logs and checkpoints - "write_logs": True + "write_logs": True, + # Preprocessing for environment + # TODO(rliaw): Convert to function similar to A#c + "preprocessing": {} } @@ -139,7 +142,7 @@ class PPOAgent(Agent): # to guard against the case where all values are equal return (value - value.mean()) / max(1e-4, value.std()) - trajectory["advantages"] = standardized(trajectory["advantages"]) + trajectory.data["advantages"] = standardized(trajectory["advantages"]) rollouts_end = time.time() print("Computing policy (iterations=" + str(config["num_sgd_iter"]) + @@ -147,7 +150,7 @@ class PPOAgent(Agent): names = [ "iter", "total loss", "policy loss", "vf loss", "kl", "entropy"] print(("{:>15}" * len(names)).format(*names)) - trajectory = shuffle(trajectory) + trajectory.data = shuffle(trajectory.data) shuffle_end = time.time() tuples_per_device = model.load_data( trajectory, self.iteration == 0 and config["full_trace_data_load"]) diff --git a/python/ray/rllib/ppo/rollout.py b/python/ray/rllib/ppo/rollout.py index 55bcca859..847e2a9bd 100644 --- a/python/ray/rllib/ppo/rollout.py +++ b/python/ray/rllib/ppo/rollout.py @@ -5,7 +5,7 @@ from __future__ import print_function import numpy as np import ray -from ray.rllib.ppo.utils import concatenate +from ray.rllib.optimizers import SampleBatch def collect_samples(agents, @@ -37,5 +37,5 @@ def collect_samples(agents, trajectories.append(trajectory) observation_filter.update(obs_f) reward_filter.update(rew_f) - return (concatenate(trajectories), np.mean(total_rewards), + return (SampleBatch.concat_samples(trajectories), np.mean(total_rewards), np.mean(trajectory_lengths)) diff --git a/python/ray/rllib/ppo/runner.py b/python/ray/rllib/ppo/runner.py index 46b738ae5..2e4e4899e 100644 --- a/python/ray/rllib/ppo/runner.py +++ b/python/ray/rllib/ppo/runner.py @@ -19,7 +19,7 @@ from ray.rllib.utils.sampler import SyncSampler from ray.rllib.utils.filter import get_filter, MeanStdFilter from ray.rllib.utils.process_rollout import process_rollout from ray.rllib.ppo.loss import ProximalPolicyLoss -from ray.rllib.ppo.utils import concatenate +from ray.rllib.optimizers import SampleBatch # TODO(pcm): Make sure that both observation_filter and reward_filter @@ -227,7 +227,7 @@ class Runner(object): (c.episode_reward, c.episode_length) for c in metrics]) updated_obs_filter = self.sampler.get_obs_filter(flush=True) return ( - concatenate(trajectories), + SampleBatch.concat_samples(trajectories), total_rewards, trajectory_lengths, updated_obs_filter, diff --git a/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml b/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml index 05c6537cd..0f62389c6 100644 --- a/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml +++ b/python/ray/rllib/tuned_examples/pong-a3c-pytorch.yaml @@ -2,16 +2,17 @@ pong-a3c-pytorch-cnn: env: PongDeterministic-v4 run: A3C resources: - cpu: 16 + cpu: 17 driver_cpu_limit: 1 config: num_workers: 16 - num_batches_per_iteration: 1000 batch_size: 20 use_lstm: false use_pytorch: true - model: + preprocessing: grayscale: true zero_mean: false dim: 80 channel_major: true + optimizer: + grads_per_step: 1000 diff --git a/python/ray/rllib/tuned_examples/pong-a3c.yaml b/python/ray/rllib/tuned_examples/pong-a3c.yaml index 03dafe6de..207b703ff 100644 --- a/python/ray/rllib/tuned_examples/pong-a3c.yaml +++ b/python/ray/rllib/tuned_examples/pong-a3c.yaml @@ -2,9 +2,15 @@ pong-a3c: env: PongDeterministic-v4 run: A3C resources: - cpu: 16 + cpu: 17 driver_cpu_limit: 1 config: num_workers: 16 - num_batches_per_iteration: 1000 batch_size: 20 + use_lstm: true + use_pytorch: false + optimizer: + grads_per_step: 1000 + preprocessing: + dim: 42 + channel_major: false diff --git a/python/ray/rllib/utils/process_rollout.py b/python/ray/rllib/utils/process_rollout.py index a066ce83d..123234c5f 100644 --- a/python/ray/rllib/utils/process_rollout.py +++ b/python/ray/rllib/utils/process_rollout.py @@ -4,6 +4,7 @@ from __future__ import print_function import numpy as np import scipy.signal +from ray.rllib.optimizers import SampleBatch def discount(x, gamma): @@ -11,7 +12,15 @@ def discount(x, gamma): def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True): - """Given a rollout, compute its value targets and the advantage.""" + """Given a rollout, compute its value targets and the advantage. + + Args: + rollout (PartialRollout): Partial Rollout Object + reward_filter (Filter): # TODO(rliaw) + + Returns: + SampleBatch (SampleBatch): Object with experience from rollout and + processed rewards.""" traj = {} trajsize = len(rollout.data["actions"]) @@ -35,6 +44,8 @@ def process_rollout(rollout, reward_filter, gamma, lambda_=1.0, use_gae=True): for i in range(traj["advantages"].shape[0]): traj["advantages"][i] = reward_filter(traj["advantages"][i]) + traj["advantages"] = traj["advantages"].copy() + assert all(val.shape[0] == trajsize for val in traj.values()), \ "Rollout stacked incorrectly!" - return traj + return SampleBatch(traj) diff --git a/python/ray/rllib/utils/sampler.py b/python/ray/rllib/utils/sampler.py index 52b5c3d36..89232cc7e 100644 --- a/python/ray/rllib/utils/sampler.py +++ b/python/ray/rllib/utils/sampler.py @@ -148,9 +148,10 @@ class AsyncSampler(threading.Thread): self.policy = policy self._obs_filter = obs_filter self._obs_f_lock = threading.Lock() - self.start() + self.started = False def run(self): + self.started = True try: self._run() except BaseException as e: @@ -213,7 +214,7 @@ class AsyncSampler(threading.Thread): Returns: rollout (PartialRollout): trajectory data (unprocessed) """ - + assert self.started, "Sampler never started running!" rollout = self._pull_batch_from_queue() return rollout