Files
ray/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py
T
Eric Liang 8aa56c12e6 [rllib] Document "v2" APIs (#2316)
* re

* wip

* wip

* a3c working

* torch support

* pg works

* lint

* rm v2

* consumer id

* clean up pg

* clean up more

* fix python 2.7

* tf session management

* docs

* dqn wip

* fix compile

* dqn

* apex runs

* up

* imports

* ddpg

* quotes

* fix tests

* fix last r

* fix tests

* lint

* pass checkpoint restore

* kwar

* nits

* policy graph

* fix yapf

* com

* class

* pyt

* vectorization

* update

* test cpe

* unit test

* fix ddpg2

* changes

* wip

* args

* faster test

* common

* fix

* add alg option

* batch mode and policy serving

* multi serving test

* todo

* wip

* serving test

* doc async env

* num envs

* comments

* thread

* remove init hook

* update

* fix ppo

* comments1

* fix

* updates

* add jenkins tests

* fix

* fix pytorch

* fix

* fixes

* fix a3c policy

* fix squeeze

* fix trunc on apex

* fix squeezing for real

* update

* remove horizon test for now

* multiagent wip

* update

* fix race condition

* fix ma

* t

* doc

* st

* wip

* example

* wip

* working

* cartpole

* wip

* batch wip

* fix bug

* make other_batches None default

* working

* debug

* nit

* warn

* comments

* fix ppo

* fix obs filter

* update

* wip

* tf

* update

* fix

* cleanup

* cleanup

* spacing

* model

* fix

* dqn

* fix ddpg

* doc

* keep names

* update

* fix

* com

* docs

* clarify model outputs

* Update torch_policy_graph.py

* fix obs filter

* pass thru worker index

* fix

* rename

* vlad torch comments

* fix log action

* debug name

* fix lstm

* remove unused ddpg net

* remove conv net

* revert lstm

* wip

* wip

* cast

* wip

* works

* fix a3c

* works

* lstm util test

* doc

* clean up

* update

* fix lstm check

* move to end

* fix sphinx

* fix cmd

* remove bad doc

* envs

* vec

* doc prep

* models

* rl

* alg

* up

* clarify

* copy

* async sa

* fix

* comments

* fix a3c conf

* tune lstm

* fix reshape

* fix

* back to 16

* tuned a3c update

* update

* tuned

* optional

* merge

* wip

* fix up

* move pg class

* rename env

* wip

* update

* tip

* alg

* readme

* fix catalog

* readme

* doc

* context

* remove prep

* comma

* add env

* link to paper

* paper

* update

* rnn

* update

* wip

* clean up ev creation

* fix

* fix

* fix

* fix lint

* up

* no comma

* ma

* Update run_multi_node_tests.sh

* fix

* sphinx is stupid

* sphinx is stupid

* clarify torch graph

* no horizon

* fix config

* sb

* Update test_optimizers.py
2018-07-01 00:05:08 -07:00

60 lines
1.6 KiB
Python

""" Run script for multiagent pendulum env. Each agent outputs a
torque which is summed to form the total torque. This is a
continuous multiagent example
"""
import gym
from gym.envs.registration import register
import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.registry import register_env
# Gym id under which the multiagent pendulum env is registered and looked up
# (both with gym and with RLlib/tune); evaluates to "MultiAgentPendulumEnv-v0".
env_version_num = 0
env_name = "{}-v{}".format("MultiAgentPendulumEnv", env_version_num)
def pass_params_to_gym(env_name):
    """Register ``env_name`` with gym, at most once per process.

    The id is bound to ``MultiAgentPendulumEnv`` in
    ``ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env``. It has
    a 100-step episode limit and no extra constructor kwargs.

    Args:
        env_name: The gym id to register, e.g. ``"MultiAgentPendulumEnv-v0"``.
    """
    # create_env() calls this every time it builds an env, so the same id
    # can reach us repeatedly within one process. Older gym releases raise
    # on re-registering an existing id, so skip ids that are already known.
    # Older gym keeps specs in ``registry.env_specs``; newer gym makes
    # ``registry`` itself a dict keyed by id.
    registry = gym.envs.registry
    specs = getattr(registry, "env_specs", registry)
    if env_name in specs:
        return
    register(
        id=env_name,
        entry_point=(
            "ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env:"
            "MultiAgentPendulumEnv"),
        max_episode_steps=100,
        kwargs={},
    )
def create_env(env_config):
    """Env creator handed to ``register_env`` for the multiagent pendulum.

    ``env_config`` is accepted to match the creator signature but is not
    read. The module-level ``env_name`` id is registered with gym via
    ``pass_params_to_gym`` and then instantiated with ``gym.envs.make``.
    """
    pass_params_to_gym(env_name)
    return gym.envs.make(env_name)
if __name__ == '__main__':
    import copy

    # Make the env constructible by name for RLlib.
    register_env(env_name, create_env)

    # DEFAULT_CONFIG.copy() would be shallow: config["model"] would alias the
    # library-wide default dict, and the .update() calls below would silently
    # mutate ppo.DEFAULT_CONFIG. Deep-copy so only this run's config changes.
    config = copy.deepcopy(ppo.DEFAULT_CONFIG)
    horizon = 10
    num_cpus = 4
    num_train_iters = 1
    ray.init(num_cpus=num_cpus, redirect_output=True)
    # One rollout worker per CPU made available to ray above.
    config["num_workers"] = num_cpus
    config["timesteps_per_batch"] = 10
    config["sgd_batchsize"] = 10
    config["num_sgd_iter"] = 10
    config["gamma"] = 0.999
    config["horizon"] = horizon
    config["use_gae"] = True
    config["model"].update({"fcnet_hiddens": [256, 256]})
    # Per-agent settings for the two agents: 3-dim observations, 1-dim
    # (torque) actions, a shared model, and [32, 32] hidden layers each.
    # NOTE(review): presumably read by the legacy multiagent custom model
    # via "custom_options" -- confirm against that model's code.
    options = {"multiagent_obs_shapes": [3, 3],
               "multiagent_act_shapes": [1, 1],
               "multiagent_shared_model": True,
               "multiagent_fcnet_hiddens": [[32, 32]] * 2}
    config["model"].update({"custom_options": options})
    alg = ppo.PPOAgent(env=env_name, config=config)
    for _ in range(num_train_iters):
        alg.train()