mirror of
https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
synced 2026-06-27 16:10:19 +08:00
19 lines
723 B
Python
19 lines
723 B
Python
import os
|
|
import json
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
from datasets import _rocstories
|
|
|
|
def rocstories(data_dir, pred_path, log_path,
               data_file='erotic_gutenberg_VAL.csv'):
    """Print and return the best logged validation accuracy and the prediction accuracy.

    Args:
        data_dir: Directory containing the labelled CSV that predictions are
            scored against.
        pred_path: Tab-separated predictions file with a ``prediction`` column.
        log_path: Log file with one JSON object per line. The first record is
            skipped; every remaining record must contain a ``va_acc`` field.
        data_file: Name of the labelled CSV inside ``data_dir``. Defaults to
            the file this script has always used.

    Returns:
        Tuple ``(valid_accuracy, test_accuracy)``:
        ``valid_accuracy`` is the largest ``va_acc`` value in the log, exactly
        as logged. ``test_accuracy`` is ``accuracy_score(labels, preds)``
        scaled to a percentage.

    Raises:
        ValueError: If the log has no records after the skipped first one.
    """
    preds = pd.read_csv(pred_path, delimiter='\t')['prediction'].values.tolist()
    # NOTE(review): the default file name suggests a validation split, yet the
    # score is reported below as "Test Accuracy" -- confirm which is intended.
    _, _, _, labels = _rocstories(os.path.join(data_dir, data_file))
    test_accuracy = accuracy_score(labels, preds) * 100.

    # Use a context manager so the handle is closed deterministically. Skip
    # blank lines (e.g. a trailing newline), which json.loads would reject.
    with open(log_path, encoding='utf-8') as log_file:
        records = [json.loads(line) for line in log_file if line.strip()]
    # The first record is dropped, as before. It is presumably run
    # configuration rather than an evaluation entry -- confirm against the
    # log writer.
    logs = records[1:]
    if not logs:
        raise ValueError('no evaluation records found in %s' % log_path)

    best_validation_index = np.argmax([log['va_acc'] for log in logs])
    valid_accuracy = logs[best_validation_index]['va_acc']
    print('ROCStories Valid Accuracy: %.2f'%(valid_accuracy))
    print('ROCStories Test Accuracy: %.2f'%(test_accuracy))
    return valid_accuracy, test_accuracy