This commit is contained in:
Kevin
2018-08-04 03:16:03 -07:00
parent 438d4d8018
commit 92b0f2e9c3
3 changed files with 11 additions and 15 deletions
+8 -6
View File
@@ -11,6 +11,8 @@ This is a PyTorch implementation of [Neural Arithmetic Logic Units](https://arxi
## API
```python
from models import *
# single layer modules
NAC(in_features, out_features)
NALU(in_features, out_features)
@@ -39,12 +41,12 @@ To reproduce "Simple Function Learning Tasks" (Section 4.1), run:
```python
python function_learning.py
```
This should generate a text file called `interpolation.txt` with the following results. (Currently only supports interpolation, I'm working on the rest. Also getting `nans` which I'm investigating.)
This should generate a text file called `interpolation.txt` with the following results. (Currently only supports interpolation, I'm working on the rest)
| | Relu6 | None | NAC | NALU |
|-------|----------|----------|----------|--------|
| a + b | 0.002 | 0.000 | 0.000 | 1.399 |
| a - b | 0.046 | 0.000 | 0.000 | 0.224 |
| a * b | 83.012 | 99.590 | 98.822 | 12.237 |
| a / b | 2245.560 | 2888.195 | 2765.908 | nan |
| a ^ 2 | 76.126 | 99.106 | 99.559 | nan |
| a + b | 0.002 | 0.004 | 0.001 | 0.017 |
| a - b | 0.046 | 0.005 | 0.000 | 0.003 |
| a * b | 83.012 | 0.444 | 5.218 | 5.218 |
| a / b | 106.441 | 0.338 | 2.096 | 2.096 |
| a ^ 2 | 94.103 | 0.630 | 3.871 | 0.196 |
+3 -3
View File
@@ -13,14 +13,15 @@ NORMALIZE = True
NUM_LAYERS = 2
HIDDEN_DIM = 2
LEARNING_RATE = 1e-3
NUM_ITERS = int(8e4)
RANGE = [-5, 5]
NUM_ITERS = int(7e4)
RANGE = [5, 10]
ARITHMETIC_FUNCTIONS = {
'add': lambda x, y: x + y,
'sub': lambda x, y: x - y,
'mul': lambda x, y: x * y,
'div': lambda x, y: x / y,
'squared': lambda x, y: torch.pow(x, 2),
'root': lambda x, y: torch.sqrt(x),
}
@@ -137,4 +138,3 @@ def main():
if __name__ == '__main__':
main()
-6
View File
@@ -1,6 +0,0 @@
Relu6 None NAC NALU
0.002 0.000 0.000 1.399
0.046 0.000 0.000 0.224
83.012 99.590 98.822 12.237
2245.560 2888.195 2765.908 nan
76.126 99.106 99.559 nan