diff --git a/rllib/BUILD b/rllib/BUILD index 94f15cef7..c98bcc209 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1643,6 +1643,30 @@ py_test( args = ["--as-test", "--torch", "--stop-reward=6.0"] ) +py_test( + name = "examples/complex_struct_space_tf", main = "examples/complex_struct_space.py", + tags = ["examples", "examples_C"], + size = "medium", + srcs = ["examples/complex_struct_space.py"], + args = ["--framework=tf"], +) + +py_test( + name = "examples/complex_struct_space_tf_eager", main = "examples/complex_struct_space.py", + tags = ["examples", "examples_C"], + size = "medium", + srcs = ["examples/complex_struct_space.py"], + args = ["--framework=tfe"], +) + +py_test( + name = "examples/complex_struct_space_torch", main = "examples/complex_struct_space.py", + tags = ["examples", "examples_C"], + size = "medium", + srcs = ["examples/complex_struct_space.py"], + args = ["--framework=torch"], +) + py_test( name = "examples/custom_env_tf", main = "examples/custom_env.py", @@ -1697,30 +1721,6 @@ py_test( args = ["--torch", "--stop-iters=1", "--num-cpus=4"] ) -py_test( - name = "examples/complex_struct_space_tf", main = "examples/complex_struct_space.py", - tags = ["examples", "examples_C"], - size = "medium", - srcs = ["examples/complex_struct_space.py"], - args = ["--framework=tf"], -) - -py_test( - name = "examples/complex_struct_space_tf_eager", main = "examples/complex_struct_space.py", - tags = ["examples", "examples_C"], - size = "medium", - srcs = ["examples/complex_struct_space.py"], - args = ["--framework=tfe"], -) - -py_test( - name = "examples/complex_struct_space_torch", main = "examples/complex_struct_space.py", - tags = ["examples", "examples_C"], - size = "medium", - srcs = ["examples/complex_struct_space.py"], - args = ["--framework=torch"], -) - py_test( name = "examples/custom_keras_model_a2c", main = "examples/custom_keras_model.py", @@ -1788,6 +1788,15 @@ py_test( args = ["--stop-iters=2"] ) +py_test( + name = "examples/custom_observation_filters", + main = "examples/custom_observation_filters.py", + tags = ["examples", "examples_C"], + size = "small", + srcs = ["examples/custom_observation_filters.py"], + args = ["--stop-iters=2"] +) + py_test( name = "examples/custom_rnn_model_repeat_after_me_tf", main = "examples/custom_rnn_model.py", diff --git a/rllib/examples/custom_observation_filters.py b/rllib/examples/custom_observation_filters.py new file mode 100644 index 000000000..95c142354 --- /dev/null +++ b/rllib/examples/custom_observation_filters.py @@ -0,0 +1,141 @@ +"""Example of a custom observation filter + +This example shows: + - using a custom observation filter + +""" +import argparse + +import numpy as np +import ray +from ray import tune +from ray.rllib.utils.filter import Filter +from ray.rllib.utils.framework import try_import_tf + +tf1, tf, tfv = try_import_tf() + +parser = argparse.ArgumentParser() +parser.add_argument("--run", type=str, default="PPO") +parser.add_argument("--stop-iters", type=int, default=200) + + +class SimpleRollingStat: + def __init__(self, n=0, m=0, s=0): + self._n = n + self._m = m + self._s = s + + def copy(self): + return SimpleRollingStat(self._n, self._m, self._s) + + def push(self, x): + self._n += 1 + delta = x - self._m + self._m += delta / self._n + self._s += delta * delta * (self._n - 1) / self._n + + def update(self, other): + n1 = self._n + n2 = other._n + n = n1 + n2 + if n == 0: + return + + delta = self._m - other._m + delta2 = delta * delta + + self._n = n + self._m = (n1 * self._m + n2 * other._m) / n + self._s = self._s + other._s + delta2 * n1 * n2 / n + + @property + def n(self): + return self._n + + @property + def mean(self): + return self._m + + @property + def var(self): + return self._s / (self._n - 1) if self._n > 1 else np.square(self._m) + + @property + def std(self): + return np.sqrt(self.var) + + +class CustomFilter(Filter): + """ + Filter that normalizes by using a single mean + and std sampled from all obs inputs + """ + is_concurrent = False + + def __init__(self, shape): + self.rs = SimpleRollingStat() + self.buffer = SimpleRollingStat() + self.shape = shape + + def clear_buffer(self): + self.buffer = SimpleRollingStat(self.shape) + + def apply_changes(self, other, with_buffer=False): + self.rs.update(other.buffer) + if with_buffer: + self.buffer = other.buffer.copy() + + def copy(self): + other = CustomFilter(self.shape) + other.sync(self) + return other + + def as_serializable(self): + return self.copy() + + def sync(self, other): + assert other.shape == self.shape, "Shapes don't match!" + self.rs = other.rs.copy() + self.buffer = other.buffer.copy() + + def __call__(self, x, update=True): + x = np.asarray(x) + if update: + if len(x.shape) == len(self.shape) + 1: + # The vectorized case. + for i in range(x.shape[0]): + self.push_stats(x[i], (self.rs, self.buffer)) + else: + # The unvectorized case. + self.push_stats(x, (self.rs, self.buffer)) + x = x - self.rs.mean + x = x / (self.rs.std + 1e-8) + return x + + @staticmethod + def push_stats(vector, buffers): + for x in vector: + for buffer in buffers: + buffer.push(x) + + def __repr__(self): + return f"CustomFilter({self.shape}, {self.rs}, {self.buffer})" + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + config = { + "env": "CartPole-v0", + "observation_filter": lambda size: CustomFilter(size), + "num_workers": 0, + } + + results = tune.run( + "PG", + args.run, + config=config, + stop={"training_iteration": args.stop_iters}) + + ray.shutdown() diff --git a/rllib/utils/filter.py b/rllib/utils/filter.py index 46f829017..683503f30 100644 --- a/rllib/utils/filter.py +++ b/rllib/utils/filter.py @@ -277,5 +277,7 @@ def get_filter(filter_config, shape): return ConcurrentMeanStdFilter(shape, clip=None) elif filter_config == "NoFilter": return NoFilter() + elif callable(filter_config): + return filter_config(shape) else: raise Exception("Unknown observation_filter: " + str(filter_config))