mirror of
https://github.com/wassname/NALU-pytorch.git
synced 2026-06-27 16:00:06 +08:00
readme++
This commit is contained in:
@@ -11,6 +11,8 @@ This is a PyTorch implementation of [Neural Arithmetic Logic Units](https://arxi
|
||||
## API
|
||||
|
||||
```python
|
||||
from models import *
|
||||
|
||||
# single layer modules
|
||||
NAC(in_features, out_features)
|
||||
NALU(in_features, out_features)
|
||||
@@ -39,12 +41,12 @@ To reproduce "Simple Function Learning Tasks" (Section 4.1), run:
|
||||
```python
|
||||
python function_learning.py
|
||||
```
|
||||
This should generate a text file called `interpolation.txt` with the following results. (Currently only supports interpolation, I'm working on the rest. Also getting `nans` which I'm investigating.)
|
||||
This should generate a text file called `interpolation.txt` with the following results. (Currently only supports interpolation, I'm working on the rest)
|
||||
|
||||
| | Relu6 | None | NAC | NALU |
|
||||
|-------|----------|----------|----------|--------|
|
||||
| a + b | 0.002 | 0.000 | 0.000 | 1.399 |
|
||||
| a - b | 0.046 | 0.000 | 0.000 | 0.224 |
|
||||
| a * b | 83.012 | 99.590 | 98.822 | 12.237 |
|
||||
| a / b | 2245.560 | 2888.195 | 2765.908 | nan |
|
||||
| a ^ 2 | 76.126 | 99.106 | 99.559 | nan |
|
||||
| a + b | 0.002 | 0.004 | 0.001 | 0.017 |
|
||||
| a - b | 0.046 | 0.005 | 0.000 | 0.003 |
|
||||
| a * b | 83.012 | 0.444 | 5.218 | 5.218 |
|
||||
| a / b | 106.441 | 0.338 | 2.096 | 2.096 |
|
||||
| a ^ 2 | 94.103 | 0.630 | 3.871 | 0.196 |
|
||||
|
||||
@@ -13,14 +13,15 @@ NORMALIZE = True
|
||||
NUM_LAYERS = 2
|
||||
HIDDEN_DIM = 2
|
||||
LEARNING_RATE = 1e-3
|
||||
NUM_ITERS = int(8e4)
|
||||
RANGE = [-5, 5]
|
||||
NUM_ITERS = int(7e4)
|
||||
RANGE = [5, 10]
|
||||
ARITHMETIC_FUNCTIONS = {
|
||||
'add': lambda x, y: x + y,
|
||||
'sub': lambda x, y: x - y,
|
||||
'mul': lambda x, y: x * y,
|
||||
'div': lambda x, y: x / y,
|
||||
'squared': lambda x, y: torch.pow(x, 2),
|
||||
'root': lambda x, y: torch.sqrt(x),
|
||||
}
|
||||
|
||||
|
||||
@@ -137,4 +138,3 @@ def main():
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
Relu6 None NAC NALU
|
||||
0.002 0.000 0.000 1.399
|
||||
0.046 0.000 0.000 0.224
|
||||
83.012 99.590 98.822 12.237
|
||||
2245.560 2888.195 2765.908 nan
|
||||
76.126 99.106 99.559 nan
|
||||
Reference in New Issue
Block a user