add changes for auto_main, tree_manager, and utils/ranking (#786)

* add changes for auto_main, tree_manager, and utils/ranking

* pre-commit changes

Co-authored-by: Alexander Mattick <alex.mattick@fau.de>
This commit is contained in:
dhug
2023-01-17 02:27:21 -05:00
committed by GitHub
parent e3387b43b8
commit 8b30c7b68e
3 changed files with 408 additions and 0 deletions
+18
View File
@@ -13,6 +13,7 @@ from oasst_backend.models import Message, MessageReaction, MessageTreeState, Tas
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlalchemy.sql import text
@@ -587,6 +588,7 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.RANKING)
return True
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
mts: MessageTreeState
@@ -603,8 +605,24 @@ class TreeManager:
return False
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
self.update_message_ranks(rankings_by_message)
return True
@managed_tx_method(CommitMode.COMMIT)
def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None:
    """Store the consensus rank of every ranked message on its Message row.

    For each entry of ``rankings_by_message`` the ``ranked_message_ids`` of
    all ranking reactions are collected, merged into one consensus order
    with ``ranked_pairs`` and the position inside that order (0 = first) is
    written to ``Message.rank``.

    NOTE(review): the ``Dict[int, int]`` annotation looks inaccurate - the
    values are iterated as collections of reaction objects. Confirm against
    the caller before tightening it.
    """
    # only the grouped rankings are needed, the keys are not used
    for reactions in rankings_by_message.values():
        sorted_messages = [reaction.payload.payload.ranked_message_ids for reaction in reactions]
        logger.debug(f"SORTED MESSAGE {sorted_messages}")

        consensus = ranked_pairs(sorted_messages)
        logger.debug(f"CONSENSUS: {consensus}\n\n")

        for position, message_id in enumerate(consensus):
            # exactly one Message row is expected per id (.one() raises otherwise)
            message = self.db.query(Message).filter(Message.id == message_id).one()
            message.rank = position
            self.db.add(message)
def _calculate_acceptance(self, labels: list[TextLabels]):
    """Return the mean of ``1 - spam`` over the spam label values in *labels*."""
    spam_key = protocol_schema.TextLabel.spam
    non_spam_scores = [1 - text_labels.labels[spam_key] for text_labels in labels]
    return np.mean(non_spam_scores)
+140
View File
@@ -0,0 +1,140 @@
from typing import List
import numpy as np
def head_to_head_votes(ranks: List[List[int]]):
    """Tally pairwise preferences over a list of rankings.

    Every ranking lists the same items, most preferred first.

    Returns
    -------
    tallies : square float array; ``tallies[i, j]`` is the number of rankings
        that place ``names[i]`` ahead of ``names[j]``
    names : the items of the first ranking in sorted order (the row/column
        labels of ``tallies``)
    """
    names = sorted(ranks[0])
    # argsort of a ranking yields, for each item in sorted-name order,
    # the position that item occupies inside the ranking
    positions = np.argsort(np.array(ranks), axis=1)
    # compare every item's position against every other item's position,
    # one (items x items) boolean slab per ranking
    placed_ahead = positions[:, :, np.newaxis] < positions[:, np.newaxis, :]
    tallies = placed_ahead.sum(axis=0).astype(float)
    return tallies, names
def cycle_detect(pairs):
    """Detect cycles by repeatedly removing Condorcet losers.

    A Condorcet loser is an item that appears as a loser but never as a
    winner. All pairs pointing at such items are dropped, round after round.
    If at some point more than one pair is left and none of the losers can
    be removed, every remaining loser is also a winner, so the pairs must
    contain a cycle.

    Parameters
    ----------
    pairs : 2-d array of (winner, loser) rows

    Returns
    -------
    out : False if the pairs do not contain a cycle, True if they do
    """
    remaining = pairs
    while len(remaining) > 1:
        winners = remaining[:, 0]
        losers = remaining[:, 1]
        # rows whose loser never wins anything point at a Condorcet loser
        points_at_condorcet_loser = ~np.isin(losers, winners)
        if not points_at_condorcet_loser.any():
            return True
        remaining = remaining[~points_at_condorcet_loser]
    return False
def get_winner(pairs):
    """
    Return _one_ item that wins at least one pair and loses none.

    Several such items may exist. Since a ranking needs a single choice,
    the smallest one is returned. If no such item exists the result is None.
    """
    defeated = np.unique(pairs[:, 1]).astype(int)
    contenders = np.unique(pairs[:, 0]).astype(int)
    undefeated = (contender for contender in contenders if contender not in defeated)
    return next(undefeated, None)
def get_ranking(pairs):
    """
    Turn acyclic (winner, loser) pairs into a ranking, best item first.

    The ranking is built by repeatedly taking an item that is not beaten by
    any item still waiting to be ranked. When several items qualify the
    smallest one is taken, so the result is deterministic but not
    necessarily the only valid ranking.

    Every item that occurs in ``pairs`` appears exactly once in the result,
    including items that only ever lose. An empty ``pairs`` gives ``[]``.

    Parameters
    ----------
    pairs : sequence or 2-d array of (winner, loser) item indices

    Returns
    -------
    out : list of int item indices, best first

    Raises
    ------
    ValueError : if the pairs contain a cycle, so no ranking exists
    """
    pair_array = np.asarray(pairs)
    if pair_array.size == 0:
        return []
    edges = [(int(winner), int(loser)) for winner, loser in pair_array]

    unranked = {item for edge in edges for item in edge}
    ranking = []
    while unranked:
        # an item is still beaten while one of its winners is unranked
        still_beaten = {loser for winner, loser in edges if winner in unranked}
        unbeaten = unranked - still_beaten
        if not unbeaten:
            raise ValueError("cannot derive a ranking: the pairs contain a cycle")
        best = min(unbeaten)
        ranking.append(best)
        unranked.remove(best)
    return ranking
def ranked_pairs(ranks: List[List[int]]):
    """
    Merge several rankings into one consensus ranking (Tideman's ranked pairs).

    Expects a list of rankings over the same items, most preferred first, like:
        [("w","x","z","y") for _ in range(3)]
        + [("w","y","x","z") for _ in range(2)]
        + [("x","y","z","w") for _ in range(4)]
        + [("x","z","w","y") for _ in range(5)]
        + [("y","w","x","z") for _ in range(1)]

    The idea is the following:
    1. create a head-to-head matrix that tallies up all win-lose combinations of preferences
    2. take all combinations that do not lose and sort those by their margin of victory
    3. lock them in one by one, skipping every pair that would close a cycle;
       this creates an (implicit) directed acyclic graph
    4. repeatedly extract nodes from the graph that do not have incoming edges
    5. the extraction order is the ranking

    Exact ties are offered in both directions: the first direction (lower
    item in sorted order wins) is locked in unless it contradicts stronger
    pairs, and the reverse direction is then rejected by the cycle check.
    This way every pair of items ends up comparable and no item is missing
    from the result.

    Returns
    -------
    out : list of the ranked items, best first. Empty for empty input.
    """
    if len(ranks) == 0:
        return []

    tallies, names = head_to_head_votes(ranks)
    # skew-symmetric: margins[i, j] > 0 means names[i] beats names[j] by that many votes
    margins = tallies - tallies.T

    # candidate pairs run over items x items (not over the ballots)
    n_items = len(names)
    majorities = [
        (i, j, margins[i, j])
        for i in range(n_items)
        for j in range(n_items)
        # you can never prefer an item over itself
        if i != j and margins[i, j] >= 0
    ]
    # order by strength of victory (Tideman's original method); the sort is
    # stable, so equally strong pairs keep their enumeration order
    majorities.sort(key=lambda majority: majority[2], reverse=True)

    lock_ins = []
    for winner, loser, _ in majorities:
        # invariant: lock_ins has no cycles here
        lock_ins.append((winner, loser))
        if cycle_detect(np.array(lock_ins)):
            # the new pair contradicts stronger pairs, drop it again
            lock_ins.pop()

    if not lock_ins:
        # fewer than two items: there is nothing to compare
        return list(names)

    numerical_ranks = np.array(get_ranking(np.array(lock_ins))).astype(int)
    return [names[n] for n in numerical_ranks]
if __name__ == "__main__":
    # small manual smoke test: nine ballots over the items w, x, y, z
    ballots = (
        [("w", "x", "z", "y")] * 1
        + [("w", "y", "x", "z")] * 2
        + [("x", "z", "w", "y")] * 5
        + [("y", "w", "x", "z")] * 1
    )
    print(ranked_pairs(ballots))
+250
View File
@@ -0,0 +1,250 @@
"""Simple REPL frontend."""
import http
import random
import requests
import typer
# Typer application object; `main` below is registered on it via @app.command().
app = typer.Typer()
# debug constants
# Fixed fake user that is sent as the "user" field of task requests and interactions.
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
def _random_message_id():
return str(random.randint(1000, 9999))
def _render_message(message: dict) -> str:
"""Render a message to the user."""
if message["is_assistant"]:
return f"Assistant: {message['text']}"
return f"Prompter: {message['text']}"
@app.command()
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
"""automates tasks"""
def _post(path: str, json: dict) -> dict:
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
response.raise_for_status()
if response.status_code == http.HTTPStatus.NO_CONTENT:
return None
return response.json()
def gen_random_text():
return " ".join([random.choice(["hello", "world", "foo", "bar"]) for _ in range(10)])
def gen_random_ranking(messages):
"""rank messages randomly and return list of indexes in order of rank randomly"""
print("Ranking")
print(messages)
print(len(messages))
ranks = [i for i in range(len(messages))]
shuffled = random.shuffle(ranks)
print(ranks)
print(shuffled)
return ranks
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
q = 0
while tasks:
task = tasks.pop(0)
print(task)
match (task["type"]):
case "initial_prompt":
typer.echo("Please provide an initial prompt to the assistant.")
if task["hint"]:
typer.echo(f"Hint: {task['hint']}")
# acknowledge task
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
prompt = gen_random_text()
user_message_id = _random_message_id()
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": prompt,
"user": USER,
},
)
tasks.append(new_task)
case "label_initial_prompt":
typer.echo("Label the following prompt:")
typer.echo(task["prompt"])
# acknowledge task
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
valid_labels = task["valid_labels"]
labels_dict = None
if task["mode"] == "simple" and len(valid_labels) == 1:
answer = random.choice([True, False])
labels_dict = {valid_labels[0]: 1 if answer else 0}
else:
while labels_dict is None:
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
# send labels
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"task_id": task["id"],
"text": task["prompt"],
"labels": labels_dict,
"user": USER,
},
)
tasks.append(new_task)
case "prompter_reply":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": gen_random_text(),
"user": USER,
},
)
tasks.append(new_task)
case "assistant_reply":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_message",
"message_id": message_id,
"task_id": task["id"],
"user_message_id": user_message_id,
"text": gen_random_text(),
"user": USER,
},
)
tasks.append(new_task)
case "rank_prompter_replies" | "rank_assistant_replies":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
ranking = gen_random_ranking(task["replies"])
print(ranking)
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "message_ranking",
"message_id": message_id,
"task_id": task["id"],
"ranking": ranking,
"user": USER,
},
)
tasks.append(new_task)
case "rank_initial_prompts":
# acknowledge task
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
# send interaction
ranking = gen_random_ranking(task["prompots"])
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "message_ranking",
"message_id": message_id,
"ranking": ranking,
"user": USER,
},
)
tasks.append(new_task)
case "label_prompter_reply" | "label_assistant_reply":
# acknowledge task
typer.echo("Here is the conversation so far:")
for message in task["conversation"]["messages"]:
typer.echo(_render_message(message))
typer.echo("Label the following reply:")
typer.echo(task["reply"])
message_id = _random_message_id()
user_message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
valid_labels = task["valid_labels"]
labels_dict = None
if task["mode"] == "simple" and len(valid_labels) == 1:
answer = random.choice([True, False])
labels_dict = {valid_labels[0]: 1 if answer else 0}
else:
while labels_dict is None:
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
if all([label in valid_labels for label in labels]):
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
else:
invalid_labels = [label for label in labels if label not in valid_labels]
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_labels",
"message_id": task["message_id"],
"task_id": task["id"],
"text": task["reply"],
"labels": labels_dict,
"user": USER,
},
)
tasks.append(new_task)
case "task_done":
typer.echo("Task done!")
# rerun with new task slected from above cases
# add a new task
q += 1
if q == 10:
typer.echo("Task done!")
break
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
#
case _:
typer.echo(f"Unknown task type {task['type']}")
# rerun with new task slected from above cases
# Run the typer CLI when this file is executed as a script.
if __name__ == "__main__":
    app()