mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Implement task selection (#383)
* commented out legacy numerical solver * added comments and task_scheduling for selecting which task to serve to users * removed standalone task weighting * pre-commit hook rerun Co-authored-by: Alexander Mattick <alex.mattick@fau.de>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import numpy as np
|
||||
from scipy import log2
|
||||
from scipy.integrate import nquad
|
||||
from scipy.special import gammaln, psi
|
||||
from scipy.stats import dirichlet
|
||||
|
||||
'''
|
||||
Legacy numerical solution.
|
||||
Should not be used as it is probably broken
|
||||
|
||||
|
||||
def make_range(*x):
|
||||
"""
|
||||
@@ -38,6 +40,23 @@ def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
|
||||
res = fun(pos)
|
||||
return np.mean(res)
|
||||
|
||||
def infogain(a_post, a_prior):
|
||||
raise (
|
||||
"""For the love of good don't use this:
|
||||
it's insanely poorly conditioned, the worst numerical code I have ever written
|
||||
and it's slow as molasses. Use the analytic solution instead.
|
||||
|
||||
Maybe remove
|
||||
"""
|
||||
)
|
||||
args = len(a_prior)
|
||||
p = dirichlet(a_post).pdf
|
||||
q = dirichlet(a_prior).pdf
|
||||
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
|
||||
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
|
||||
return info
|
||||
'''
|
||||
|
||||
|
||||
def analytic_solution(a_post, a_prior):
|
||||
"""
|
||||
@@ -57,26 +76,8 @@ def analytic_solution(a_post, a_prior):
|
||||
return info
|
||||
|
||||
|
||||
def infogain(a_post, a_prior):
|
||||
raise (
|
||||
"""For the love of good don't use this:
|
||||
it's insanely poorly conditioned, the worst numerical code I have ever written
|
||||
and it's slow as molasses. Use the analytic solution instead.
|
||||
|
||||
Maybe remove
|
||||
"""
|
||||
)
|
||||
args = len(a_prior)
|
||||
p = dirichlet(a_post).pdf
|
||||
q = dirichlet(a_prior).pdf
|
||||
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
|
||||
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
|
||||
return info
|
||||
|
||||
|
||||
def uniform_expected_infogain(a_prior):
|
||||
mean_weight = dirichlet.mean(a_prior)
|
||||
print("weight", mean_weight)
|
||||
results = []
|
||||
for i, w in enumerate(mean_weight):
|
||||
a_post = a_prior.copy()
|
||||
|
||||
@@ -87,8 +87,9 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
|
||||
"""
|
||||
This function returns the gain of points for a given prompt's votes
|
||||
|
||||
This function is only to be run when archiving a question
|
||||
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
|
||||
In contrast to the other score updating functions, we can run this online as new votes come in.
|
||||
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information.
|
||||
|
||||
|
||||
Parameters:
|
||||
consensus (ArrayLike): all votes cast for this question
|
||||
@@ -100,7 +101,8 @@ def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
|
||||
# produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
|
||||
# since 100 is the lowest, 300 the highest and 200 the middle value
|
||||
consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1
|
||||
delta_votes = np.sum(consensus_ranking * consensus)
|
||||
# expected consenus ranking (i.e. normalize the votes and multiply-sum with weightings)
|
||||
delta_votes = np.sum(consensus_ranking * consensus / sum(consensus))
|
||||
new_points = delta_votes + voter_data.prompt_points
|
||||
|
||||
# we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
|
||||
@@ -133,7 +135,7 @@ def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.Arr
|
||||
"research design and statistical analyses, second edition, 2003"
|
||||
the authors note that at least from an significance test POV they will yield the same p-values
|
||||
|
||||
Parameters:
|
||||
Parameters:
|
||||
user_ranking (ArrayLike): ranking produced by the user
|
||||
consensus (ArrayLike): ranking produced after running the voting algorithm to merge into the consensus ranking
|
||||
voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
from scipy import optimize
|
||||
|
||||
|
||||
class Task(Enum):
|
||||
RANKING = 0
|
||||
ANSWER = 1
|
||||
PROMPT = 2
|
||||
VOTE = 3
|
||||
|
||||
|
||||
def task_selection(
|
||||
num_ranking_tasks: int, current_prompts: int, target_num_prompts: int, p: float, answers_per_prompt: int
|
||||
) -> Task:
|
||||
"""
|
||||
This computes which task to serve to the user.
|
||||
In general, this method aims to get rankable tasks out of the active pool ASAP.
|
||||
Before checking anything else, we first have a p% probability of running a ranking task.
|
||||
After that, we can dynamically determine which task to serve by balancing the number of active tasks.
|
||||
|
||||
Parameters:
|
||||
num_ranking_tasks (int): number of prompts that are ready to do ranking (i.e. have "answers_per_prompt" many answers)
|
||||
current_prompts (int): how many prompts are currently in the active pool
|
||||
target_num_prompts (int): how many prompts _should_ be in the active pool
|
||||
p (float): probability to serve a ranking task, if one is available
|
||||
answers_per_prompt (int): number of answers we want to have per prompt
|
||||
Returns:
|
||||
task (Task): the task Enum that corresponds to one of the four tasks
|
||||
"""
|
||||
if num_ranking_tasks > 0 and np.random.rand() < p:
|
||||
return Task.RANKING
|
||||
rate = 50 / (current_prompts * 2)
|
||||
prob_prompt_task = 0.5 + (target_num_prompts - current_prompts) * rate
|
||||
# Yes, I'm too lazy to solve this analytically...
|
||||
prob_unfinished_prompt = optimize.linprog(
|
||||
np.array([1, 1]), A_eq=np.array([[1, 1], [1, -answers_per_prompt]]), b_eq=np.array([1, 0]), bounds=(0, None)
|
||||
).x[0]
|
||||
if np.random.rand() < prob_prompt_task:
|
||||
if np.random.rand() < prob_unfinished_prompt:
|
||||
return Task.ANSWER
|
||||
else:
|
||||
return Task.PROMPT
|
||||
else:
|
||||
return Task.VOTE
|
||||
|
||||
|
||||
def next_answer_task(possible_prompts, answers_per_prompt):
|
||||
"""
|
||||
If the `task_selection`method returns "answer", you can use this method to decide which
|
||||
prompt should get an answer next.
|
||||
The goal of this is to finish off the prompts that have almost enough answers collected already:
|
||||
I.e. if we want 5 answers, this is going to give preferential sampling to those prompts that already
|
||||
have 4/5 answers.
|
||||
This helps to not have too much close-to-finished prompts in the active set.
|
||||
|
||||
Parameters:
|
||||
possible_prompts (dict[prompt_id, num_answers]): a dictonary containing all open prompts and the number of answers these prompts currently have.
|
||||
answers_per_prompt (int): number of answers we per prompt to target
|
||||
Returns:
|
||||
prompt_id (int): the prompt_id corresponding to the next prompt that should get a new answer
|
||||
"""
|
||||
nums = list(set(possible_prompts.values()))
|
||||
p = np.array([max(x / answers_per_prompt, 1 / answers_per_prompt) for x in nums])
|
||||
idx = np.random.choice(nums, p=p / p.sum())
|
||||
sample = np.random.choice([k for k, v in possible_prompts.items() if v == idx])
|
||||
return sample
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = task_selection(1, 500, 1000, 0.1, 5)
|
||||
print(x)
|
||||
y = next_answer_task({"this": 2, "is": 4, "a": 1, "test": 4}, 5)
|
||||
print(y)
|
||||
Reference in New Issue
Block a user