diff --git a/README.md b/README.md index 23940e4..8d333e2 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ [PyTorch](https://github.com/pytorch/pytorch.git) version of [_Torch for Numpy users_](https://github.com/torch/torch7/wiki/Torch-for-Numpy-users). - ## Types | Numpy | PyTorch | @@ -19,6 +18,8 @@ ## Constructors +### Ones and zeros + | Numpy | PyTorch | |:-------------------|:---------------------------------------| | `np.empty((2, 3))` | `torch.Tensor(2, 3)` | diff --git a/conversions.yaml b/conversions.yaml index cdef6b9..3e81a8d 100644 --- a/conversions.yaml +++ b/conversions.yaml @@ -15,22 +15,23 @@ types: pytorch: torch.LongTensor constructors: - - numpy: np.empty((2, 3)) - pytorch: torch.Tensor(2, 3) - - numpy: np.empty_like(x) - pytorch: x.new(x.size()).type(x.type()) - - numpy: np.eye - pytorch: torch.eye - - numpy: np.identity - pytorch: torch.eye - - numpy: np.ones - pytorch: torch.ones - - numpy: np.ones_like - pytorch: torch.ones(x.size()).type(x.type()) - - numpy: np.zeros - pytorch: torch.zeros - - numpy: np.zeros_like - pytorch: torch.zeros(x.size()).type(x.type()) + ones and zeros: + - numpy: np.empty((2, 3)) + pytorch: torch.Tensor(2, 3) + - numpy: np.empty_like(x) + pytorch: x.new(x.size()).type(x.type()) + - numpy: np.eye + pytorch: torch.eye + - numpy: np.identity + pytorch: torch.eye + - numpy: np.ones + pytorch: torch.ones + - numpy: np.ones_like + pytorch: torch.ones(x.size()).type(x.type()) + - numpy: np.zeros + pytorch: torch.zeros + - numpy: np.zeros_like + pytorch: torch.zeros(x.size()).type(x.type()) # - numpy: x.astype(np.int32) # pytorch: x.type(torch.IntTensor) diff --git a/generate_readme.py b/generate_readme.py index cec8d16..66978f5 100755 --- a/generate_readme.py +++ b/generate_readme.py @@ -21,6 +21,26 @@ TEMPLATE = '''\ here = osp.dirname(osp.abspath(__file__)) +def get_section(title, data, h=2): + if not isinstance(data, list): + content = '%s %s\n\n' % ('#' * h, title.capitalize()) + for sub_title, sub_data in data.items(): + content += get_section(sub_title, sub_data, h=h+1) + return content + + headers = ['Numpy', 'PyTorch'] + rows = [] + for d in data: + rows.append([ + '`' + d['numpy'] + '`', + '`' + d['pytorch'] + '`', + ]) + + content = '%s %s\n\n' % ('#' * h, title.capitalize()) + content += tabulate.tabulate(rows, headers=headers, tablefmt='pipe') + '\n' + return content + + def get_contents(): # keep order in yaml file yaml.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, @@ -29,23 +49,11 @@ def get_contents(): yaml_file = osp.join(here, 'conversions.yaml') data = yaml.load(open(yaml_file)) - contents = '' - for section, data in data.items(): - headers = ['Numpy', 'PyTorch'] - rows = [] - for d in data: - rows.append([ - '`' + d['numpy'] + '`', - '`' + d['pytorch'] + '`', - ]) - contents += ''' -## {title} - -{table}\n'''.format( - title=section.capitalize(), - table=tabulate.tabulate(rows, headers=headers, tablefmt='pipe'), - ) - return contents + contents = [] + for title, data in data.items(): + section = get_section(title, data) + contents.append(section) + return '\n'.join(contents) print(TEMPLATE.format(contents=get_contents()))