mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
infogain analytic solution
This commit is contained in:
@@ -0,0 +1,98 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
from scipy import log2
|
||||
from scipy.integrate import nquad
|
||||
from scipy.special import gammaln, psi
|
||||
from scipy.stats import dirichlet
|
||||
|
||||
|
||||
def make_range(*x):
    """
    Integration bounds for the next simplex coordinate.

    Given the coordinates chosen so far, the next one may range from 0 up to
    whatever is left of the total mass of 1:
    (0, x_k) with x_k = 1 - (x_1 + ... + x_(k-1)).
    The upper bound is clamped so that it never drops below 0.
    """
    leftover = 1 - sum(x)
    upper = leftover if leftover > 0 else 0
    return (0, upper)
|
||||
|
||||
|
||||
def relative_entropy(p, q):
    """
    Build the integrand of the relative entropy (in bits) between two densities.

    Args:
        p: callable returning the density of the first distribution at a
            complete point (all coordinates, including the last one).
        q: callable returning the density of the second distribution at the
            same kind of point.

    Returns:
        A function taking every coordinate except the last and returning
        p(x) * log2(p(x) / q(x)) at the completed point.
    """

    def integrand(*x):
        """
        Evaluate p * log2(p / q) at the point completed from the free coordinates.

        The last coordinate is not free: it is forced to 1 - (x_1 + ... + x_N),
        so it is appended before p and q are evaluated.
        """
        x_full = np.append(x, 1 - sum(x))
        # Evaluate p once and reuse it for both the weight and the log-ratio.
        p_val = p(x_full)
        # np.log2 replaces the deprecated scipy.log2 alias.
        return p_val * np.log2(p_val / q(x_full))

    return integrand
|
||||
|
||||
|
||||
def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
    """
    Average `fun` over randomly drawn points whose coordinates sum to 1.

    For every sample, dim - 1 uniform random cut points in [0, 1) are drawn,
    padded with 0 and 1, sorted, and differenced. Each column of the resulting
    (dim, samples) array is therefore non-negative and sums to 1.

    Args:
        fun: callable applied once to the whole (dim, samples) array.
        dim: number of coordinates per point.
        samples: number of random points to draw.

    Returns:
        np.mean of whatever `fun` returns.
    """
    cuts = np.random.rand(dim - 1, samples)
    floor_row = np.zeros((1, samples))
    ceil_row = np.ones((1, samples))
    edges = np.sort(np.concatenate((floor_row, cuts, ceil_row)), axis=0)
    # Gaps between consecutive sorted edges are the coordinates of each point.
    points = np.diff(edges, axis=0)
    return np.mean(fun(points))
|
||||
|
||||
|
||||
def analytic_solution(a_post, a_prior):
    """
    Analytic solution to the KL-divergence between two dirichlet distributions.

    Computes KL(Dir(a_post) || Dir(a_prior)) in nats:

        ln G(sum a_post) - ln G(sum a_prior)
        - sum ln G(a_post_i) + sum ln G(a_prior_i)
        + sum (a_post_i - a_prior_i) * (psi(a_post_i) - psi(sum a_post))

    where G is the gamma function and psi the digamma function.

    Args:
        a_post: concentration parameters of the posterior (array-like, > 0).
        a_prior: concentration parameters of the prior, same length as a_post.

    Returns:
        The divergence as a float; it is 0 when both parameter vectors are
        equal and positive otherwise.
    """
    # Accept plain lists/tuples as well as arrays.
    a_post = np.asarray(a_post, dtype=float)
    a_prior = np.asarray(a_prior, dtype=float)

    post_sum = np.sum(a_post)
    prior_sum = np.sum(a_prior)

    # Expected value of ln(x_i) under Dir(a_post).
    expected_log_x = psi(a_post) - psi(post_sum)

    info = (
        gammaln(post_sum)
        - gammaln(prior_sum)
        - np.sum(gammaln(a_post))
        + np.sum(gammaln(a_prior))
        # This term enters with a plus sign: it is E_post[ln p - ln q]
        # restricted to the (a_i - 1) * ln(x_i) parts of the log-densities.
        + np.sum((a_post - a_prior) * expected_log_x)
    )

    return info
|
||||
|
||||
|
||||
def infogain(a_post, a_prior):
    """
    Disabled numerical estimate of the information gain between two dirichlets.

    The abandoned implementation integrated the relative entropy of the two
    dirichlet densities over the simplex with scipy's nquad. It is poorly
    conditioned and very slow, so this function refuses to run; use the
    analytic solution instead.

    Args:
        a_post: concentration parameters of the posterior (unused).
        a_prior: concentration parameters of the prior (unused).

    Raises:
        NotImplementedError: always.
    """
    # Raising a bare string is itself a TypeError in Python 3 and hides the
    # message, so raise a real exception carrying the warning text.
    # Candidate for removal.
    raise NotImplementedError(
        "For the love of good don't use this: "
        "it's insanely poorly conditioned, the worst numerical code I have ever written "
        "and it's slow as molasses. Use the analytic solution instead."
    )
|
||||
|
||||
|
||||
def uniform_expected_infogain(a_prior):
    """
    Expected value of analytic_solution after one more observation.

    For each category i the prior counts are copied, entry i is incremented by
    one, and analytic_solution(incremented, a_prior) is weighted by the prior
    mean of category i. The weighted values are summed.

    Args:
        a_prior: numpy array of dirichlet concentration parameters.

    Returns:
        The weighted sum described above.
    """
    mean_weight = dirichlet.mean(a_prior)
    print("weight", mean_weight)

    def gain_after_observing(index):
        """Value of analytic_solution when category `index` gains one count."""
        a_post = a_prior.copy()
        a_post[index] = a_post[index] + 1
        return analytic_solution(a_post, a_prior)

    weighted_gains = [w * gain_after_observing(i) for i, w in enumerate(mean_weight)]
    return np.sum(weighted_gains)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Small demo: compare a prior with a posterior that differs only in the
    # third category (9 -> 20 counts).
    a_prior = np.array([1, 1, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    a_post = a_prior.copy()
    a_post[2] = 20

    print("algebraic", analytic_solution(a_post, a_prior))
    # The numerical variant, infogain(a_post, a_prior), is intentionally not called here.
    print("large infogain", uniform_expected_infogain(a_prior))
    print("post infogain", uniform_expected_infogain(a_post))
    # A strongly peaked prior such as np.array([1, 1, 1000]) can be passed to
    # uniform_expected_infogain to see a small expected gain.
|
||||
@@ -5,6 +5,7 @@ numpy==1.22.4
|
||||
psycopg2-binary==2.9.5
|
||||
pydantic==1.9.1
|
||||
python-dotenv==0.21.0
|
||||
scipy==1.8.1
|
||||
SQLAlchemy==1.4.41
|
||||
sqlmodel==0.0.8
|
||||
starlette==0.22.0
|
||||
|
||||
Reference in New Issue
Block a user