# -*- coding: utf-8 -*-
# Provenance: reconstructed from a whitespace-collapsed `git format-patch` mail.
#   Commit:  be424c96a2d158f91cf4a7a57df45e2af73d8d33
#   From:    Alexander Mattick
#   Date:    Sun, 25 Dec 2022 11:45:29 +0100
#   Subject: [PATCH] infogain analytic solution
#   Files:   creates backend/postprocessing/infogain_selector.py (this module)
#            and adds the pin `scipy==1.8.1` to backend/requirements.txt
#            (inserted between python-dotenv==0.21.0 and SQLAlchemy==1.4.41).
"""Information gain between Dirichlet distributions.

The information gain of an update is measured as the KL divergence
KL(posterior || prior) between two Dirichlet distributions.
``analytic_solution`` evaluates it in closed form (in nats); the quadrature
based path is kept only for reference and is disabled in ``infogain``.
"""
import logging

import numpy as np
from numpy import log2  # ``scipy.log2`` was a deprecated alias of this function
from scipy.integrate import nquad
from scipy.special import gammaln, psi
from scipy.stats import dirichlet

logger = logging.getLogger(__name__)


def make_range(*x):
    """Return the integration bounds for the next simplex coordinate.

    Given the coordinates ``x`` that are already fixed, the next coordinate
    may range over ``(0, 1 - sum(x))``. The upper bound is clamped at 0 so
    that accumulated overshoot never yields a negative interval.
    """
    return (0, max(0, 1 - sum(x)))


def relative_entropy(p, q):
    """Build the integrand ``p(x) * log2(p(x) / q(x))`` over the simplex.

    Args:
        p: density of the posterior, called with a full point on the simplex.
        q: density of the prior, called with the same point.

    Returns:
        A function taking the first ``N - 1`` coordinates as positional
        arguments; the last coordinate is implied by the sum-to-one
        constraint. The result is measured in bits (base-2 logarithm).
    """

    def integrand(*x):
        # Append the forced last entry x_last = 1 - (x_1 + ... + x_(N-1)).
        point = np.append(x, 1 - sum(x))
        density_p = np.asarray(p(point), dtype=float)
        density_q = np.asarray(q(point), dtype=float)
        with np.errstate(divide="ignore", invalid="ignore"):
            term = density_p * log2(density_p / density_q)
        # Convention 0 * log(0) = 0: points without posterior mass add nothing
        # (the naive product would be NaN there). ``[()]`` unwraps 0-d arrays.
        return np.where(density_p > 0, term, 0.0)[()]

    return integrand


def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
    """Average ``fun`` over points drawn uniformly from the simplex.

    Points are generated as the spacings between ``dim - 1`` sorted uniform
    numbers padded with 0 and 1, so every column of the sample matrix has
    ``dim`` non-negative entries that sum to one.

    Args:
        fun: callable receiving the whole ``(dim, samples)`` array at once and
            returning one value per sample.
        dim: number of simplex coordinates.
        samples: number of random points to draw.

    Returns:
        The mean of ``fun`` over the samples. NOTE: no simplex-volume factor
        is applied, so this is an expectation under the uniform distribution
        rather than a plain integral.
    """
    cuts = np.random.rand(dim - 1, samples)
    lower = np.zeros((1, samples))
    upper = np.ones((1, samples))
    ordered = np.sort(np.concatenate((lower, cuts, upper)), 0)
    points = np.diff(ordered, axis=0)
    return np.mean(fun(points))


def analytic_solution(a_post, a_prior):
    """Closed-form KL divergence ``KL(Dir(a_post) || Dir(a_prior))`` in nats.

    With ``A = sum(a_post)`` and ``B = sum(a_prior)``::

        KL = lnG(A) - lnG(B) - sum(lnG(a_post)) + sum(lnG(a_prior))
             + sum((a_post - a_prior) * (psi(a_post) - psi(A)))

    The last term is the expectation of ``(a_post - a_prior) . ln(x)`` under
    the posterior and therefore enters with a plus sign. The original author
    refers to a proof in the project's Notion design doc.

    Args:
        a_post: concentration parameters of the posterior (array-like, > 0).
        a_prior: concentration parameters of the prior, same length.

    Returns:
        The divergence as a float; 0 when both parameter vectors are equal.
    """
    a_post = np.asarray(a_post, dtype=float)
    a_prior = np.asarray(a_prior, dtype=float)
    post_sum = np.sum(a_post)
    prior_sum = np.sum(a_prior)
    info = (
        gammaln(post_sum)
        - gammaln(prior_sum)
        - np.sum(gammaln(a_post))
        + np.sum(gammaln(a_prior))
        + np.sum((a_post - a_prior) * (psi(a_post) - psi(post_sum)))
    )
    return info


def _infogain_by_quadrature(a_post, a_prior):
    """Integrate the relative entropy numerically with ``nquad`` (in bits).

    Kept for reference only: the nested quadrature is poorly conditioned and
    its cost explodes with the number of categories.
    """
    n_categories = len(a_prior)
    p = dirichlet(a_post).pdf
    q = dirichlet(a_prior).pdf
    # One range callable per free coordinate; the last one is implied.
    ranges = [make_range for _ in range(n_categories - 1)]
    info, _ = nquad(relative_entropy(p, q), ranges, opts={"epsabs": 1e-8})
    return info


def infogain(a_post, a_prior):
    """Disabled numerical information gain; always raises.

    Raises:
        NotImplementedError: always. Use ``analytic_solution`` instead.
    """
    # A proper exception type is required here: raising a bare string is
    # itself a TypeError and would hide the intended message.
    raise NotImplementedError(
        "For the love of god don't use this: it's insanely poorly conditioned, "
        "the worst numerical code I have ever written and it's slow as molasses. "
        "Use the analytic solution instead."
    )


def uniform_expected_infogain(a_prior):
    """Expected information gain of observing one more sample.

    Each category ``i`` is observed with probability equal to the prior mean
    ``a_prior[i] / sum(a_prior)``; the observation increments that category's
    concentration by one. The result is the probability-weighted sum of the
    resulting KL divergences (in nats). ``a_prior`` is not modified.
    """
    a_prior = np.asarray(a_prior, dtype=float)
    mean_weight = dirichlet.mean(a_prior)
    logger.debug("weight %s", mean_weight)
    results = []
    for i, w in enumerate(mean_weight):
        a_post = a_prior.copy()
        a_post[i] += 1
        results.append(w * analytic_solution(a_post, a_prior))
    return np.sum(results)


def main():
    """Print a small demonstration of the analytic information gain."""
    a_prior = np.array([1, 1, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    a_post = np.array([1, 1, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

    print("algebraic", analytic_solution(a_post, a_prior))
    print("large infogain", uniform_expected_infogain(a_prior))
    print("post infogain", uniform_expected_infogain(a_post))


if __name__ == "__main__":
    main()