Files
ray/python/ray/experimental/sgd/examples/tensorflow_train_example.py
T
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py. Fails otherwise on a py2/print related error.
2020-01-09 00:15:48 -08:00

132 lines
3.4 KiB
Python

import argparse
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np
import ray
from ray import tune
from ray.experimental.sgd.tf.tf_trainer import TFTrainer, TFTrainable
# Number of synthetic samples generated for the training split; also used
# as the shuffle buffer size and to derive the training steps per epoch.
NUM_TRAIN_SAMPLES = 1000
# Number of synthetic samples generated for the evaluation split; also used
# to derive the number of evaluation steps.
NUM_TEST_SAMPLES = 400
def create_config(batch_size):
    """Assemble the trainer configuration dictionary for one batch size.

    Args:
        batch_size: Number of samples per batch. Must be non-zero, since it
            is used as an integer divisor.

    Returns:
        A dict with three entries: ``"batch_size"`` (the value passed in),
        ``"fit_config"`` holding ``"steps_per_epoch"`` (training samples
        floor-divided by the batch size) and ``"evaluate_config"`` holding
        ``"steps"`` (test samples floor-divided by the batch size).
    """
    # Floor division: a trailing partial batch is not counted as a step.
    train_steps = NUM_TRAIN_SAMPLES // batch_size
    eval_steps = NUM_TEST_SAMPLES // batch_size

    return {
        "batch_size": batch_size,
        "fit_config": {"steps_per_epoch": train_steps},
        "evaluate_config": {"steps": eval_steps},
    }
def linear_dataset(a=2, size=1000):
    """Generate a random linear regression dataset ``y = x / a``.

    The divisor used to be hard-coded as ``2`` and the ``a`` argument was
    silently ignored. It is now honoured; the default ``a=2`` yields exactly
    the same targets as before.

    Args:
        a: Divisor applied to the inputs to produce the targets.
        size: Number of samples to draw.

    Returns:
        A tuple ``(x, y)`` of NumPy arrays, each of shape ``(size, 1)``.
        ``x`` is drawn with ``np.random.rand`` and ``y`` equals ``x / a``.
    """
    # Reshape to a column vector first so the targets inherit the same
    # (size, 1) layout from the division.
    x = np.random.rand(size).reshape((-1, 1))
    y = x / a
    return x, y
def simple_dataset(config):
    """Create the batched training and test ``tf.data`` datasets.

    Args:
        config: Mapping that must contain a ``"batch_size"`` entry.

    Returns:
        A tuple ``(train_dataset, test_dataset)``. Both repeat indefinitely
        and are batched with ``config["batch_size"]``; only the training
        dataset is shuffled (buffer of ``NUM_TRAIN_SAMPLES`` elements).
    """
    per_batch = config["batch_size"]

    # Draw the training split before the test split so the order of random
    # number consumption stays the same.
    train_arrays = linear_dataset(size=NUM_TRAIN_SAMPLES)
    test_arrays = linear_dataset(size=NUM_TEST_SAMPLES)

    train_slices = tf.data.Dataset.from_tensor_slices(tuple(train_arrays))
    test_slices = tf.data.Dataset.from_tensor_slices(tuple(test_arrays))

    shuffled = train_slices.shuffle(NUM_TRAIN_SAMPLES)
    train_batches = shuffled.repeat().batch(per_batch)
    test_batches = test_slices.repeat().batch(per_batch)

    return train_batches, test_batches
def simple_model(config):
    """Build and compile a small two-layer dense Keras model.

    Args:
        config: Accepted for the model-creator call signature; not read by
            this function.

    Returns:
        A compiled ``Sequential`` model: ``Dense(10)`` on a 1-feature input
        followed by ``Dense(1)``, using the ``"sgd"`` optimizer with
        mean squared error as both the loss and the tracked metric.
    """
    layers = [Dense(10, input_shape=(1, )), Dense(1)]
    net = Sequential(layers)

    compile_options = {
        "optimizer": "sgd",
        "loss": "mean_squared_error",
        "metrics": ["mean_squared_error"],
    }
    net.compile(**compile_options)
    return net
def train_example(num_replicas=1, batch_size=128, use_gpu=False):
    """Run two train/validate rounds with a ``TFTrainer`` and print stats.

    Args:
        num_replicas: Forwarded to ``TFTrainer`` as ``num_replicas``.
        batch_size: Batch size used to build the trainer config.
        use_gpu: Forwarded to ``TFTrainer`` as ``use_gpu``.

    Each round prints the training statistics merged with the validation
    statistics. A final standalone validation result is printed afterwards,
    followed by ``success!``.
    """
    trainer_kwargs = {
        "model_creator": simple_model,
        "data_creator": simple_dataset,
        "num_replicas": num_replicas,
        "use_gpu": use_gpu,
        "verbose": True,
        "config": create_config(batch_size),
    }
    trainer = TFTrainer(**trainer_kwargs)

    # Two identical rounds: train, fold the validation results into the
    # training stats, then report.
    for _ in range(2):
        round_stats = trainer.train()
        round_stats.update(trainer.validate())
        print(round_stats)

    print(trainer.validate())
    print("success!")
def tune_example(num_replicas=1, use_gpu=False):
    """Run ``TFTrainable`` through ``tune.run`` and return the best config.

    Args:
        num_replicas: Stored in the trial config as ``"num_replicas"``.
        use_gpu: Stored in the trial config as ``"use_gpu"``.

    Returns:
        The result of ``analysis.get_best_config`` for the minimum
        ``"validation_loss"``.
    """
    trial_config = dict(
        model_creator=tune.function(simple_model),
        data_creator=tune.function(simple_dataset),
        num_replicas=num_replicas,
        use_gpu=use_gpu,
        trainer_config=create_config(batch_size=128),
    )
    stop_criteria = {"training_iteration": 2}

    analysis = tune.run(
        TFTrainable,
        num_samples=2,
        config=trial_config,
        stop=stop_criteria,
        verbose=1,
    )
    return analysis.get_best_config(metric="validation_loss", mode="min")
if __name__ == "__main__":
    # Command-line entry point: parse the flags, connect to Ray, then run
    # either the Tune example or the plain training example.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray",
    )
    cli.add_argument(
        "--num-replicas",
        "-n",
        type=int,
        default=1,
        help="Sets number of replicas for training.",
    )
    cli.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Enables GPU training",
    )
    cli.add_argument(
        "--tune",
        action="store_true",
        default=False,
        help="Tune training",
    )

    # Unrecognised arguments are tolerated and discarded.
    args = cli.parse_known_args()[0]

    ray.init(address=args.address)

    # Both examples accept the same keyword arguments, so pick the callable
    # and invoke it once.
    run_example = tune_example if args.tune else train_example
    run_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)