diff --git a/python/ray/rllib/examples/multiagent_mountaincar.py b/python/ray/rllib/examples/multiagent_mountaincar.py index f585dde1f..74f818d7e 100644 --- a/python/ray/rllib/examples/multiagent_mountaincar.py +++ b/python/ray/rllib/examples/multiagent_mountaincar.py @@ -36,18 +36,18 @@ def create_env(env_config): if __name__ == '__main__': register_env(env_name, lambda env_config: create_env(env_config)) config = ppo.DEFAULT_CONFIG.copy() - horizon = 200 - num_cpus = 2 - ray.init(num_cpus=num_cpus, redirect_output=False) + horizon = 10 + num_cpus = 4 + ray.init(num_cpus=num_cpus, redirect_output=True) config["num_workers"] = num_cpus - config["timesteps_per_batch"] = 100 + config["timesteps_per_batch"] = 10 config["num_sgd_iter"] = 10 config["gamma"] = 0.999 config["horizon"] = horizon - config["use_gae"] = True + config["use_gae"] = False config["model"].update({"fcnet_hiddens": [256, 256]}) options = {"multiagent_obs_shapes": [2, 2], - "multiagent_act_shapes": [3, 3], + "multiagent_act_shapes": [1, 1], "multiagent_shared_model": False, "multiagent_fcnet_hiddens": [[32, 32]] * 2} config["model"].update({"custom_options": options}) diff --git a/python/ray/rllib/examples/multiagent_pendulum.py b/python/ray/rllib/examples/multiagent_pendulum.py index 9e629ed07..20cd5d7ac 100644 --- a/python/ray/rllib/examples/multiagent_pendulum.py +++ b/python/ray/rllib/examples/multiagent_pendulum.py @@ -36,11 +36,11 @@ def create_env(env_config): if __name__ == '__main__': register_env(env_name, lambda env_config: create_env(env_config)) config = ppo.DEFAULT_CONFIG.copy() - horizon = 100 - num_cpus = 2 - ray.init(num_cpus=num_cpus, redirect_output=False) + horizon = 10 + num_cpus = 4 + ray.init(num_cpus=num_cpus, redirect_output=True) config["num_workers"] = num_cpus - config["timesteps_per_batch"] = 100 + config["timesteps_per_batch"] = 10 config["num_sgd_iter"] = 10 config["gamma"] = 0.999 config["horizon"] = horizon diff --git a/python/ray/rllib/utils/reshaper.py b/python/ray/rllib/utils/reshaper.py index 37a96ebab..c0687b488 100644 --- a/python/ray/rllib/utils/reshaper.py +++ b/python/ray/rllib/utils/reshaper.py @@ -14,10 +14,10 @@ class Reshaper(object): if isinstance(env_space, list): for space in env_space: # Handle both gym arrays and just lists of inputs length - if hasattr(space, "shape"): - arr_shape = np.asarray(space.shape) - elif hasattr(space, "n"): + if hasattr(space, "n"): arr_shape = np.asarray([1]) # discrete space + elif hasattr(space, "shape"): + arr_shape = np.asarray(space.shape) else: arr_shape = space self.shapes.append(arr_shape)