# Mirror of https://github.com/wassname/Open-Assistant.git (synced 2026-06-27 16:10:30 +08:00)
# 98 lines, 2.8 KiB, Python
import numpy as np
|
|
from scipy import log2
|
|
from scipy.integrate import nquad
|
|
from scipy.special import gammaln, psi
|
|
from scipy.stats import dirichlet
|
|
|
|
|
|
def make_range(*x):
    """
    Bounds for the next simplex coordinate given the ones already chosen.

    The coordinates passed in use up ``sum(x)`` of the total mass of 1, so the
    next coordinate may lie between 0 and ``1 - sum(x)``. The upper bound is
    floored at 0 so it never goes negative when the inputs already exceed 1.
    """
    remaining = 1 - sum(x)
    upper = remaining if remaining > 0 else 0
    return (0, upper)
|
|
|
|
|
|
def relative_entropy(p, q):
    """
    Build the pointwise relative-entropy integrand for two densities.

    ``p`` and ``q`` are callables that take a full point on the simplex (all
    coordinates, summing to 1) and return the density at that point. The
    returned function takes only the free coordinates, appends the forced last
    one, and evaluates ``p * log2(p / q)`` there, so the result is in bits.
    """

    def tmp(*x):
        """
        First adds the last always forced entry to the input
        (x_last = 1 - (x_1 + ... + x_N)), then computes the relative-entropy
        integrand of posterior and prior at that point.
        """
        x_new = np.append(x, 1 - sum(x))
        # Evaluate the posterior density once and reuse it.
        p_val = p(x_new)
        # Convention 0 * log(0) = 0: where p vanishes the integrand is zero,
        # not the NaN that 0 * -inf would produce.
        if np.ndim(p_val) == 0 and p_val == 0:
            return 0.0
        # np.log2 replaces the deprecated log2 alias from the scipy namespace;
        # for the positive density ratios used here the values are identical.
        return p_val * np.log2(p_val / q(x_new))

    return tmp
|
|
|
|
|
|
def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
    """
    Monte Carlo estimate of the integral of ``fun`` over the probability simplex.

    Points are drawn uniformly from the simplex with ``dim`` coordinates by
    sorting ``dim - 1`` uniform cut points in [0, 1] and taking the gaps
    between them. ``fun`` is called once with an array of shape
    ``(dim, samples)`` (one column per sample) and must return one value per
    sample.

    The integral is taken with respect to the first ``dim - 1`` coordinates,
    the same parametrisation an iterated integral over the free coordinates
    uses. That region has volume ``1 / (dim - 1)!``, so the sample mean is
    scaled by that volume to turn an average into an integral.

    Parameters:
        fun: vectorised integrand, called as ``fun(pos)``.
        dim: number of simplex coordinates (integer >= 1).
        samples: number of random points to draw.

    Returns:
        The estimated value of the integral.
    """
    import math  # local import: only needed for the simplex volume factor

    cuts = np.random.rand(dim - 1, samples)
    # Pad with 0 and 1 and sort along the coordinate axis; consecutive
    # differences are then non-negative and sum to 1 for every sample.
    edges = np.sort(np.concatenate((np.zeros((1, samples)), cuts, np.ones((1, samples)))), 0)
    pos = np.diff(edges, axis=0)
    res = fun(pos)
    # mean(res) estimates the average of fun over the simplex; multiplying by
    # the volume 1 / (dim - 1)! gives the integral itself.
    return np.mean(res) / math.factorial(dim - 1)
|
|
|
|
|
|
def analytic_solution(a_post, a_prior):
    """
    Analytic solution to the KL-divergence between two dirichlet distributions.

    Computes KL(Dir(a_post) || Dir(a_prior)) in nats:

        ln G(sum a_post) - ln G(sum a_prior)
        - sum ln G(a_post) + sum ln G(a_prior)
        + sum (a_post - a_prior) * (psi(a_post) - psi(sum a_post))

    where G is the gamma function and psi the digamma function.
    Proof is in the Notion design doc.

    Parameters:
        a_post: concentration parameters of the posterior (array-like).
        a_prior: concentration parameters of the prior (array-like, same length).

    Returns:
        The KL divergence as a float; it is 0 when both parameter vectors match.
    """
    # Accept plain lists as well as arrays; the arithmetic below is elementwise.
    a_post = np.asarray(a_post, dtype=float)
    a_prior = np.asarray(a_prior, dtype=float)

    post_sum = np.sum(a_post)
    prior_sum = np.sum(a_prior)

    log_normaliser_ratio = (
        gammaln(post_sum)
        - gammaln(prior_sum)
        - np.sum(gammaln(a_post))
        + np.sum(gammaln(a_prior))
    )
    # E_post[ln x_i] = psi(a_post_i) - psi(post_sum). This term enters with a
    # plus sign; subtracting it gives e.g. ln2 + 0.5 instead of the true
    # ln2 - 0.5 for Dir(2, 1) against Dir(1, 1).
    expected_log_term = np.sum((a_post - a_prior) * (psi(a_post) - psi(post_sum)))

    info = log_normaliser_ratio + expected_log_term
    return info
|
|
|
|
|
|
def _numeric_infogain(a_post, a_prior):
    """Reference implementation: integrate the relative entropy numerically with nquad."""
    args = len(a_prior)
    p = dirichlet(a_post).pdf
    q = dirichlet(a_prior).pdf
    # One nested integral per free coordinate; the last coordinate is forced.
    (info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
    # info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
    return info


def infogain(a_post, a_prior):
    """
    Deliberately disabled numerical information gain between two dirichlets.

    Always raises: the numerical integration is badly conditioned and very
    slow, so callers should use ``analytic_solution`` instead. The previous
    code raised a bare string, which Python rejects with an unrelated
    ``TypeError`` and which hid this message; a real exception is raised now.
    The numerical code is kept in ``_numeric_infogain`` for reference.

    Raises:
        NotImplementedError: on every call.
    """
    # TODO: maybe remove this function (and _numeric_infogain) entirely.
    raise NotImplementedError(
        "For the love of good don't use this: "
        "it's insanely poorly conditioned, the worst numerical code I have ever written "
        "and it's slow as molasses. Use the analytic solution instead."
    )
|
|
|
|
|
|
def uniform_expected_infogain(a_prior):
    """
    Weighted sum of information gains from incrementing each parameter by one.

    For every index ``i`` a copy of ``a_prior`` with entry ``i`` increased by 1
    is compared against ``a_prior`` via ``analytic_solution``; each result is
    weighted by entry ``i`` of the dirichlet mean of ``a_prior``, and the
    weighted values are summed. The mean weights are printed as a side effect.
    """
    mean_weight = dirichlet.mean(a_prior)
    print("weight", mean_weight)

    def gain_for(index):
        # Posterior parameters after bumping a single entry of the prior.
        bumped = a_prior.copy()
        bumped[index] = bumped[index] + 1
        return analytic_solution(bumped, a_prior)

    weighted_gains = [weight * gain_for(index) for index, weight in enumerate(mean_weight)]
    return np.sum(weighted_gains)
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: the posterior differs from the prior only in the third entry (9 -> 20).
    a_prior = np.array([1, 1, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    a_post = a_prior.copy()
    a_post[2] = 20

    # Each value is computed lazily inside the loop so the output order stays
    # the same as computing and printing one line at a time.
    reports = (
        ("algebraic", lambda: analytic_solution(a_post, a_prior)),
        # ("raw", lambda: infogain(a_post, a_prior)),
        ("large infogain", lambda: uniform_expected_infogain(a_prior)),
        ("post infogain", lambda: uniform_expected_infogain(a_post)),
    )
    for label, compute in reports:
        print(label, compute())
    # a_prior = np.array([1,1,1000])
    # print("small infogain",uniform_expected_infogain(a_prior))
|