diff --git a/README.md b/README.md index a2d4024..a89a8ca 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ This is a PyTorch implementation of [Neural Arithmetic Logic Units](https://arxi ## API ```python +from models import * + # single layer modules NAC(in_features, out_features) NALU(in_features, out_features) @@ -39,12 +41,12 @@ To reproduce "Simple Function Learning Tasks" (Section 4.1), run: ```python python function_learning.py ``` -This should generate a text file called `interpolation.txt` with the following results. (Currently only supports interpolation, I'm working on the rest. Also getting `nans` which I'm investigating.) +This should generate a text file called `interpolation.txt` with the following results. (Currently only supports interpolation, I'm working on the rest) | | Relu6 | None | NAC | NALU | |-------|----------|----------|----------|--------| -| a + b | 0.002 | 0.000 | 0.000 | 1.399 | -| a - b | 0.046 | 0.000 | 0.000 | 0.224 | -| a * b | 83.012 | 99.590 | 98.822 | 12.237 | -| a / b | 2245.560 | 2888.195 | 2765.908 | nan | -| a ^ 2 | 76.126 | 99.106 | 99.559 | nan | +| a + b | 0.002 | 0.004 | 0.001 | 0.017 | +| a - b | 0.046 | 0.005 | 0.000 | 0.003 | +| a * b | 83.012 | 0.444 | 5.218 | 5.218 | +| a / b | 106.441 | 0.338 | 2.096 | 2.096 | +| a ^ 2 | 94.103 | 0.630 | 3.871 | 0.196 | diff --git a/function_learning.py b/function_learning.py index ee2bce5..3cb8b69 100644 --- a/function_learning.py +++ b/function_learning.py @@ -13,14 +13,15 @@ NORMALIZE = True NUM_LAYERS = 2 HIDDEN_DIM = 2 LEARNING_RATE = 1e-3 -NUM_ITERS = int(8e4) -RANGE = [-5, 5] +NUM_ITERS = int(7e4) +RANGE = [5, 10] ARITHMETIC_FUNCTIONS = { 'add': lambda x, y: x + y, 'sub': lambda x, y: x - y, 'mul': lambda x, y: x * y, 'div': lambda x, y: x / y, 'squared': lambda x, y: torch.pow(x, 2), + 'root': lambda x, y: torch.sqrt(x), } @@ -137,4 +138,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/results/interpolation.txt b/results/interpolation.txt deleted file mode 100644 index 8e9ec43..0000000 --- a/results/interpolation.txt +++ /dev/null @@ -1,6 +0,0 @@ -Relu6 None NAC NALU -0.002 0.000 0.000 1.399 -0.046 0.000 0.000 0.224 -83.012 99.590 98.822 12.237 -2245.560 2888.195 2765.908 nan -76.126 99.106 99.559 nan