Implement task selection (#383)

* commented out legacy numerical solver

* added comments and task_scheduling for selecting which task to serve to users

* removed standalone task weighting

* pre-commit hook rerun

Co-authored-by: Alexander Mattick <alex.mattick@fau.de>
This commit is contained in:
MattAlexMiracle
2023-01-05 16:14:39 +01:00
committed by GitHub
parent 894219423d
commit 3dbe0ae1ba
3 changed files with 102 additions and 24 deletions
+21 -20
View File
@@ -1,9 +1,11 @@
import numpy as np
from scipy import log2
from scipy.integrate import nquad
from scipy.special import gammaln, psi
from scipy.stats import dirichlet
'''
Legacy numerical solution.
Should not be used as it is probably broken
def make_range(*x):
"""
@@ -38,6 +40,23 @@ def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
res = fun(pos)
return np.mean(res)
def infogain(a_post, a_prior):
raise (
"""For the love of good don't use this:
it's insanely poorly conditioned, the worst numerical code I have ever written
and it's slow as molasses. Use the analytic solution instead.
Maybe remove
"""
)
args = len(a_prior)
p = dirichlet(a_post).pdf
q = dirichlet(a_prior).pdf
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
return info
'''
def analytic_solution(a_post, a_prior):
"""
@@ -57,26 +76,8 @@ def analytic_solution(a_post, a_prior):
return info
def infogain(a_post, a_prior):
raise (
"""For the love of good don't use this:
it's insanely poorly conditioned, the worst numerical code I have ever written
and it's slow as molasses. Use the analytic solution instead.
Maybe remove
"""
)
args = len(a_prior)
p = dirichlet(a_post).pdf
q = dirichlet(a_prior).pdf
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
return info
def uniform_expected_infogain(a_prior):
mean_weight = dirichlet.mean(a_prior)
print("weight", mean_weight)
results = []
for i, w in enumerate(mean_weight):
a_post = a_prior.copy()