From 8b30c7b68ea3196a87041f892fa619feed5a1220 Mon Sep 17 00:00:00 2001 From: dhug <38571110+danielpatrickhug@users.noreply.github.com> Date: Tue, 17 Jan 2023 02:27:21 -0500 Subject: [PATCH] add changes for auto_main, tree_manager, and utils/ranking (#786) * add changes for auto_main, tree_manager, and utils/ranking * pre-commit changes Co-authored-by: Alexander Mattick --- backend/oasst_backend/tree_manager.py | 18 ++ backend/oasst_backend/utils/ranking.py | 140 ++++++++++++++ text-frontend/auto_main.py | 250 +++++++++++++++++++++++++ 3 files changed, 408 insertions(+) create mode 100644 backend/oasst_backend/utils/ranking.py create mode 100644 text-frontend/auto_main.py diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 225b0146..a2c85940 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -13,6 +13,7 @@ from oasst_backend.models import Message, MessageReaction, MessageTreeState, Tas from oasst_backend.prompt_repository import PromptRepository from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI +from oasst_backend.utils.ranking import ranked_pairs from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlalchemy.sql import text @@ -587,6 +588,7 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.RANKING) return True + @managed_tx_method(CommitMode.COMMIT) def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_scoring_state({message_tree_id=})") mts: MessageTreeState @@ -603,8 +605,24 @@ class TreeManager: return False self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) + self.update_message_ranks(rankings_by_message) return True + @managed_tx_method(CommitMode.COMMIT) + def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None: + for parent_msg_id, ranking in rankings_by_message.items(): + sorted_messages = [] + for msg_reaction in ranking: + sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids) + logger.debug(f"SORTED MESSAGE {sorted_messages}") + consensus = ranked_pairs(sorted_messages) + logger.debug(f"CONSENSUS: {consensus}\n\n") + for rank, uuid in enumerate(consensus): + # set rank for each message_id for Message rows + msg = self.db.query(Message).filter(Message.id == uuid).one() + msg.rank = rank + self.db.add(msg) + def _calculate_acceptance(self, labels: list[TextLabels]): # calculate acceptance based on spam label return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) diff --git a/backend/oasst_backend/utils/ranking.py b/backend/oasst_backend/utils/ranking.py new file mode 100644 index 00000000..f6e7a31e --- /dev/null +++ b/backend/oasst_backend/utils/ranking.py @@ -0,0 +1,140 @@ +from typing import List + +import numpy as np + + +def head_to_head_votes(ranks: List[List[int]]): + tallies = np.zeros((len(ranks[0]), len(ranks[0]))) + names = sorted(ranks[0]) + ranks = np.array(ranks) + # we want the sorted indices + ranks = np.argsort(ranks, axis=1) + for i in range(ranks.shape[1]): + for j in range(i + 1, ranks.shape[1]): + # now count the cases someone voted for i over j + over_j = np.sum(ranks[:, i] < ranks[:, j]) + over_i = np.sum(ranks[:, j] < ranks[:, i]) + tallies[i, j] = over_j + # tallies[i,j] = over_i + tallies[j, i] = over_i + # tallies[j,i] = over_j + return tallies, names + + +def cycle_detect(pairs): + """Recursively detect cylces by removing condorcet losers until either only one pair is left or condorcet loosers no longer exist + This method upholds the invariant that in a ranking for all a,b either a>b or b>a for all a,b. + + + Returns + ------- + out : False if the pairs do not contain a cycle, True if the pairs contain a cycle + + + """ + # get all condorcet losers (pairs that loose to all other pairs) + # idea: filter all losers that are never winners + # print("pairs", pairs) + if len(pairs) <= 1: + return False + losers = [c_lose for c_lose in np.unique(pairs[:, 1]) if c_lose not in pairs[:, 0]] + if len(losers) == 0: + # if we recursively removed pairs, and at some point we did not have + # a condorcet loser, that means everything is both a winner and loser, + # yielding at least one (winner,loser), (loser,winner) pair + return True + + new = [] + for p in pairs: + if p[1] not in losers: + new.append(p) + return cycle_detect(np.array(new)) + + +def get_winner(pairs): + """ + This returns _one_ concordant winner. + It could be that there are multiple concordant winners, but in our case + since we are interested in a ranking, we have to choose one at random. + """ + losers = np.unique(pairs[:, 1]).astype(int) + winners = np.unique(pairs[:, 0]).astype(int) + for w in winners: + if w not in losers: + return w + + +def get_ranking(pairs): + """ + Abuses concordance property to get a (not necessarily unqiue) ranking. + The lack of uniqueness is due to the potential existence of multiple + equally ranked winners. We have to pick one, which is where + the non-uniqueness comes from + """ + if len(pairs) == 1: + return list(pairs[0]) + w = get_winner(pairs) + # now remove the winner from the list of pairs + p_new = np.array([(a, b) for a, b in pairs if a != w]) + return [w] + get_ranking(p_new) + + +def ranked_pairs(ranks: List[List[int]]): + """ + Expects a list of rankings for an item like: + [("w","x","z","y") for _ in range(3)] + + [("w","y","x","z") for _ in range(2)] + + [("x","y","z","w") for _ in range(4)] + + [("x","z","w","y") for _ in range(5)] + + [("y","w","x","z") for _ in range(1)] + This code is quite brain melting, but the idea is the following: + 1. create a head-to-head matrix that tallies up all win-lose combinations of preferences + 2. take all combinations that win more than they loose and sort those by how often they win + 3. use that to create an (implicit) directed graph + 4. recursively extract nodes from the graph that do not have incoming edges + 5. said recursive list is the ranking + """ + tallies, names = head_to_head_votes(ranks) + tallies = tallies - tallies.T + # print(tallies) + # note: the resulting tally matrix should be skew-symmetric + # order by strength of victory (using tideman's original method, don't think it would make a difference for us) + sorted_majorities = [] + for i in range(len(ranks[0])): + for j in range(len(ranks[i])): + if tallies[i, j] > 0: + sorted_majorities.append((i, j, tallies[i, j])) + # we don't explicitly deal with tied majorities here + sorted_majorities = np.array(sorted(sorted_majorities, key=lambda x: x[2], reverse=True)) + # now do lock ins + lock_ins = [] + for (x, y, _) in sorted_majorities: + # invariant: lock_ins has no cycles here + lock_ins.append((x, y)) + # print("lock ins are now",np.array(lock_ins)) + if cycle_detect(np.array(lock_ins)): + # print("backup: cycle detected") + # if there's a cycle, delete the new addition and continue + lock_ins = lock_ins[:-1] + # now simply return all winners in order, and attach the losers + # to the back. This is because the overall loser might not be unique + # and (by concordance property) may never exist in any winning set to begin with. + # (otherwise he would either not be the loser, or cycles exist!) + # Since there could be multiple overall losers, we just return them in any order + # as we are unable to find a closer ranking + numerical_ranks = np.array(get_ranking(np.array(lock_ins))).astype(int) + conversion = [names[n] for n in numerical_ranks] + return conversion + + +if __name__ == "__main__": + ranks = ( + [("w", "x", "z", "y") for _ in range(1)] + + [("w", "y", "x", "z") for _ in range(2)] + # + [("x","y","z","w") for _ in range(4)] + + [("x", "z", "w", "y") for _ in range(5)] + + [("y", "w", "x", "z") for _ in range(1)] + # [("y","z","w","x") for _ in range(1000)] + ) + rp = ranked_pairs(ranks) + print(rp) diff --git a/text-frontend/auto_main.py b/text-frontend/auto_main.py new file mode 100644 index 00000000..cea07c1e --- /dev/null +++ b/text-frontend/auto_main.py @@ -0,0 +1,250 @@ +"""Simple REPL frontend.""" + +import http +import random + +import requests +import typer + +app = typer.Typer() + + +# debug constants +USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"} + + +def _random_message_id(): + return str(random.randint(1000, 9999)) + + +def _render_message(message: dict) -> str: + """Render a message to the user.""" + if message["is_assistant"]: + return f"Assistant: {message['text']}" + return f"Prompter: {message['text']}" + + +@app.command() +def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): + """automates tasks""" + + def _post(path: str, json: dict) -> dict: + response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key}) + response.raise_for_status() + if response.status_code == http.HTTPStatus.NO_CONTENT: + return None + return response.json() + + def gen_random_text(): + return " ".join([random.choice(["hello", "world", "foo", "bar"]) for _ in range(10)]) + + def gen_random_ranking(messages): + """rank messages randomly and return list of indexes in order of rank randomly""" + print("Ranking") + print(messages) + print(len(messages)) + ranks = [i for i in range(len(messages))] + shuffled = random.shuffle(ranks) + print(ranks) + print(shuffled) + return ranks + + tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})] + q = 0 + while tasks: + task = tasks.pop(0) + print(task) + + match (task["type"]): + case "initial_prompt": + typer.echo("Please provide an initial prompt to the assistant.") + if task["hint"]: + typer.echo(f"Hint: {task['hint']}") + # acknowledge task + message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + + prompt = gen_random_text() + user_message_id = _random_message_id() + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_reply_to_message", + "message_id": message_id, + "task_id": task["id"], + "user_message_id": user_message_id, + "text": prompt, + "user": USER, + }, + ) + tasks.append(new_task) + + case "label_initial_prompt": + typer.echo("Label the following prompt:") + typer.echo(task["prompt"]) + # acknowledge task + message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + + valid_labels = task["valid_labels"] + + labels_dict = None + if task["mode"] == "simple" and len(valid_labels) == 1: + answer = random.choice([True, False]) + labels_dict = {valid_labels[0]: 1 if answer else 0} + else: + while labels_dict is None: + labels = random.sample(valid_labels, random.randint(1, len(valid_labels))) + + if all([label in valid_labels for label in labels]): + labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + else: + invalid_labels = [label for label in labels if label not in valid_labels] + typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") + + # send labels + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_labels", + "message_id": task["message_id"], + "task_id": task["id"], + "text": task["prompt"], + "labels": labels_dict, + "user": USER, + }, + ) + tasks.append(new_task) + case "prompter_reply": + # acknowledge task + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_reply_to_message", + "message_id": message_id, + "task_id": task["id"], + "user_message_id": user_message_id, + "text": gen_random_text(), + "user": USER, + }, + ) + tasks.append(new_task) + + case "assistant_reply": + # acknowledge task + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_reply_to_message", + "message_id": message_id, + "task_id": task["id"], + "user_message_id": user_message_id, + "text": gen_random_text(), + "user": USER, + }, + ) + tasks.append(new_task) + + case "rank_prompter_replies" | "rank_assistant_replies": + # acknowledge task + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + # send interaction + ranking = gen_random_ranking(task["replies"]) + print(ranking) + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "message_ranking", + "message_id": message_id, + "task_id": task["id"], + "ranking": ranking, + "user": USER, + }, + ) + tasks.append(new_task) + + case "rank_initial_prompts": + # acknowledge task + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + # send interaction + ranking = gen_random_ranking(task["prompots"]) + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "message_ranking", + "message_id": message_id, + "ranking": ranking, + "user": USER, + }, + ) + tasks.append(new_task) + + case "label_prompter_reply" | "label_assistant_reply": + # acknowledge task + typer.echo("Here is the conversation so far:") + for message in task["conversation"]["messages"]: + typer.echo(_render_message(message)) + + typer.echo("Label the following reply:") + typer.echo(task["reply"]) + message_id = _random_message_id() + user_message_id = _random_message_id() + _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id}) + valid_labels = task["valid_labels"] + + labels_dict = None + if task["mode"] == "simple" and len(valid_labels) == 1: + answer = random.choice([True, False]) + labels_dict = {valid_labels[0]: 1 if answer else 0} + else: + while labels_dict is None: + labels = random.sample(valid_labels, random.randint(1, len(valid_labels))) + + if all([label in valid_labels for label in labels]): + labels_dict = {label: "1" if label in labels else "0" for label in valid_labels} + else: + invalid_labels = [label for label in labels if label not in valid_labels] + typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}") + # send interaction + new_task = _post( + "/api/v1/tasks/interaction", + { + "type": "text_labels", + "message_id": task["message_id"], + "task_id": task["id"], + "text": task["reply"], + "labels": labels_dict, + "user": USER, + }, + ) + tasks.append(new_task) + case "task_done": + typer.echo("Task done!") + # rerun with new task slected from above cases + # add a new task + q += 1 + if q == 10: + typer.echo("Task done!") + break + tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})] + # + case _: + typer.echo(f"Unknown task type {task['type']}") + # rerun with new task slected from above cases + + +if __name__ == "__main__": + app()