mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add changes for auto_main, tree_manager, and utils/ranking (#786)
* add changes for auto_main, tree_manager, and utils/ranking * pre-commit changes Co-authored-by: Alexander Mattick <alex.mattick@fau.de>
This commit is contained in:
@@ -13,6 +13,7 @@ from oasst_backend.models import Message, MessageReaction, MessageTreeState, Tas
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy.sql import text
|
||||
@@ -587,6 +588,7 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.RANKING)
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
|
||||
mts: MessageTreeState
|
||||
@@ -603,8 +605,24 @@ class TreeManager:
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
self.update_message_ranks(rankings_by_message)
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def update_message_ranks(self, rankings_by_message: Dict[int, int]) -> None:
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
sorted_messages = []
|
||||
for msg_reaction in ranking:
|
||||
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
|
||||
logger.debug(f"SORTED MESSAGE {sorted_messages}")
|
||||
consensus = ranked_pairs(sorted_messages)
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
for rank, uuid in enumerate(consensus):
|
||||
# set rank for each message_id for Message rows
|
||||
msg = self.db.query(Message).filter(Message.id == uuid).one()
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def head_to_head_votes(ranks: List[List[int]]):
|
||||
tallies = np.zeros((len(ranks[0]), len(ranks[0])))
|
||||
names = sorted(ranks[0])
|
||||
ranks = np.array(ranks)
|
||||
# we want the sorted indices
|
||||
ranks = np.argsort(ranks, axis=1)
|
||||
for i in range(ranks.shape[1]):
|
||||
for j in range(i + 1, ranks.shape[1]):
|
||||
# now count the cases someone voted for i over j
|
||||
over_j = np.sum(ranks[:, i] < ranks[:, j])
|
||||
over_i = np.sum(ranks[:, j] < ranks[:, i])
|
||||
tallies[i, j] = over_j
|
||||
# tallies[i,j] = over_i
|
||||
tallies[j, i] = over_i
|
||||
# tallies[j,i] = over_j
|
||||
return tallies, names
|
||||
|
||||
|
||||
def cycle_detect(pairs):
|
||||
"""Recursively detect cylces by removing condorcet losers until either only one pair is left or condorcet loosers no longer exist
|
||||
This method upholds the invariant that in a ranking for all a,b either a>b or b>a for all a,b.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : False if the pairs do not contain a cycle, True if the pairs contain a cycle
|
||||
|
||||
|
||||
"""
|
||||
# get all condorcet losers (pairs that loose to all other pairs)
|
||||
# idea: filter all losers that are never winners
|
||||
# print("pairs", pairs)
|
||||
if len(pairs) <= 1:
|
||||
return False
|
||||
losers = [c_lose for c_lose in np.unique(pairs[:, 1]) if c_lose not in pairs[:, 0]]
|
||||
if len(losers) == 0:
|
||||
# if we recursively removed pairs, and at some point we did not have
|
||||
# a condorcet loser, that means everything is both a winner and loser,
|
||||
# yielding at least one (winner,loser), (loser,winner) pair
|
||||
return True
|
||||
|
||||
new = []
|
||||
for p in pairs:
|
||||
if p[1] not in losers:
|
||||
new.append(p)
|
||||
return cycle_detect(np.array(new))
|
||||
|
||||
|
||||
def get_winner(pairs):
|
||||
"""
|
||||
This returns _one_ concordant winner.
|
||||
It could be that there are multiple concordant winners, but in our case
|
||||
since we are interested in a ranking, we have to choose one at random.
|
||||
"""
|
||||
losers = np.unique(pairs[:, 1]).astype(int)
|
||||
winners = np.unique(pairs[:, 0]).astype(int)
|
||||
for w in winners:
|
||||
if w not in losers:
|
||||
return w
|
||||
|
||||
|
||||
def get_ranking(pairs):
|
||||
"""
|
||||
Abuses concordance property to get a (not necessarily unqiue) ranking.
|
||||
The lack of uniqueness is due to the potential existence of multiple
|
||||
equally ranked winners. We have to pick one, which is where
|
||||
the non-uniqueness comes from
|
||||
"""
|
||||
if len(pairs) == 1:
|
||||
return list(pairs[0])
|
||||
w = get_winner(pairs)
|
||||
# now remove the winner from the list of pairs
|
||||
p_new = np.array([(a, b) for a, b in pairs if a != w])
|
||||
return [w] + get_ranking(p_new)
|
||||
|
||||
|
||||
def ranked_pairs(ranks: List[List[int]]):
|
||||
"""
|
||||
Expects a list of rankings for an item like:
|
||||
[("w","x","z","y") for _ in range(3)]
|
||||
+ [("w","y","x","z") for _ in range(2)]
|
||||
+ [("x","y","z","w") for _ in range(4)]
|
||||
+ [("x","z","w","y") for _ in range(5)]
|
||||
+ [("y","w","x","z") for _ in range(1)]
|
||||
This code is quite brain melting, but the idea is the following:
|
||||
1. create a head-to-head matrix that tallies up all win-lose combinations of preferences
|
||||
2. take all combinations that win more than they loose and sort those by how often they win
|
||||
3. use that to create an (implicit) directed graph
|
||||
4. recursively extract nodes from the graph that do not have incoming edges
|
||||
5. said recursive list is the ranking
|
||||
"""
|
||||
tallies, names = head_to_head_votes(ranks)
|
||||
tallies = tallies - tallies.T
|
||||
# print(tallies)
|
||||
# note: the resulting tally matrix should be skew-symmetric
|
||||
# order by strength of victory (using tideman's original method, don't think it would make a difference for us)
|
||||
sorted_majorities = []
|
||||
for i in range(len(ranks[0])):
|
||||
for j in range(len(ranks[i])):
|
||||
if tallies[i, j] > 0:
|
||||
sorted_majorities.append((i, j, tallies[i, j]))
|
||||
# we don't explicitly deal with tied majorities here
|
||||
sorted_majorities = np.array(sorted(sorted_majorities, key=lambda x: x[2], reverse=True))
|
||||
# now do lock ins
|
||||
lock_ins = []
|
||||
for (x, y, _) in sorted_majorities:
|
||||
# invariant: lock_ins has no cycles here
|
||||
lock_ins.append((x, y))
|
||||
# print("lock ins are now",np.array(lock_ins))
|
||||
if cycle_detect(np.array(lock_ins)):
|
||||
# print("backup: cycle detected")
|
||||
# if there's a cycle, delete the new addition and continue
|
||||
lock_ins = lock_ins[:-1]
|
||||
# now simply return all winners in order, and attach the losers
|
||||
# to the back. This is because the overall loser might not be unique
|
||||
# and (by concordance property) may never exist in any winning set to begin with.
|
||||
# (otherwise he would either not be the loser, or cycles exist!)
|
||||
# Since there could be multiple overall losers, we just return them in any order
|
||||
# as we are unable to find a closer ranking
|
||||
numerical_ranks = np.array(get_ranking(np.array(lock_ins))).astype(int)
|
||||
conversion = [names[n] for n in numerical_ranks]
|
||||
return conversion
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ranks = (
|
||||
[("w", "x", "z", "y") for _ in range(1)]
|
||||
+ [("w", "y", "x", "z") for _ in range(2)]
|
||||
# + [("x","y","z","w") for _ in range(4)]
|
||||
+ [("x", "z", "w", "y") for _ in range(5)]
|
||||
+ [("y", "w", "x", "z") for _ in range(1)]
|
||||
# [("y","z","w","x") for _ in range(1000)]
|
||||
)
|
||||
rp = ranked_pairs(ranks)
|
||||
print(rp)
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
import http
|
||||
import random
|
||||
|
||||
import requests
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# debug constants
|
||||
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
|
||||
|
||||
|
||||
def _random_message_id():
|
||||
return str(random.randint(1000, 9999))
|
||||
|
||||
|
||||
def _render_message(message: dict) -> str:
|
||||
"""Render a message to the user."""
|
||||
if message["is_assistant"]:
|
||||
return f"Assistant: {message['text']}"
|
||||
return f"Prompter: {message['text']}"
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
"""automates tasks"""
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
if response.status_code == http.HTTPStatus.NO_CONTENT:
|
||||
return None
|
||||
return response.json()
|
||||
|
||||
def gen_random_text():
|
||||
return " ".join([random.choice(["hello", "world", "foo", "bar"]) for _ in range(10)])
|
||||
|
||||
def gen_random_ranking(messages):
|
||||
"""rank messages randomly and return list of indexes in order of rank randomly"""
|
||||
print("Ranking")
|
||||
print(messages)
|
||||
print(len(messages))
|
||||
ranks = [i for i in range(len(messages))]
|
||||
shuffled = random.shuffle(ranks)
|
||||
print(ranks)
|
||||
print(shuffled)
|
||||
return ranks
|
||||
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
q = 0
|
||||
while tasks:
|
||||
task = tasks.pop(0)
|
||||
print(task)
|
||||
|
||||
match (task["type"]):
|
||||
case "initial_prompt":
|
||||
typer.echo("Please provide an initial prompt to the assistant.")
|
||||
if task["hint"]:
|
||||
typer.echo(f"Hint: {task['hint']}")
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
prompt = gen_random_text()
|
||||
user_message_id = _random_message_id()
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": prompt,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "label_initial_prompt":
|
||||
typer.echo("Label the following prompt:")
|
||||
typer.echo(task["prompt"])
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
valid_labels = task["valid_labels"]
|
||||
|
||||
labels_dict = None
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer = random.choice([True, False])
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
while labels_dict is None:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["prompt"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "prompter_reply":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": gen_random_text(),
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "assistant_reply":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": gen_random_text(),
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "rank_prompter_replies" | "rank_assistant_replies":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
ranking = gen_random_ranking(task["replies"])
|
||||
print(ranking)
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "rank_initial_prompts":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
ranking = gen_random_ranking(task["prompots"])
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "label_prompter_reply" | "label_assistant_reply":
|
||||
# acknowledge task
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
|
||||
typer.echo("Label the following reply:")
|
||||
typer.echo(task["reply"])
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
valid_labels = task["valid_labels"]
|
||||
|
||||
labels_dict = None
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer = random.choice([True, False])
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
while labels_dict is None:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["reply"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "task_done":
|
||||
typer.echo("Task done!")
|
||||
# rerun with new task slected from above cases
|
||||
# add a new task
|
||||
q += 1
|
||||
if q == 10:
|
||||
typer.echo("Task done!")
|
||||
break
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
#
|
||||
case _:
|
||||
typer.echo(f"Unknown task type {task['type']}")
|
||||
# rerun with new task slected from above cases
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
Reference in New Issue
Block a user