In [1]:
import numpy as np
import matplotlib.pyplot as plt
import math
import scipy.special as spsp
import scipy.stats as spst
%matplotlib inline
import sympy
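Let's use rejection sampling with a uniform distribution on $[0, 1)$ as the proposal distribution, $g(x) = 1$, to generate samples from our target distribution (the Beta(2, 4) density):
$$f(x)=20x(1-x)^3, \quad 0\leq x<1$$
A proposal $x \sim g$ is accepted with probability $f(x)/(C\,g(x))$, where the envelope constant $C$ must satisfy $C \ge \max_x f(x)/g(x)$. The next cells compute $C$ first numerically and then analytically.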
In [5]:
# Numerical estimate of the envelope constant C = max_x f(x) / g(x), where f is the
# target density and g = 1 is the Uniform(0, 1) proposal density.
x = np.linspace(0, 1, 100000)
pdf_t = 20 * x * (1 - x)**3   # target density f(x) evaluated on the grid
pdf_p = 1                     # proposal density g(x) = 1 on [0, 1)
ratio_numeric = pdf_t / pdf_p

fig, ax = plt.subplots()
ax.plot(x, ratio_numeric)
ax.set(xlabel="x", ylabel="f(x) / g(x)", title="Ratio of target to proposal density")
plt.show()

# np.max(a, axis=None, ...): the second positional argument is `axis`, so reduce the
# ratio array alone. A grid maximum can only under-estimate the true supremum, which
# is why the exact value from the analytical cell is what the sampler uses.
c_numeric = np.max(ratio_numeric)
c_numeric
In [14]:
# Analytical: compute the envelope constant C = max_x f(x) / g(x) exactly with SymPy.
x = sympy.Symbol("x")
pdf_t = 20 * x * (1 - x)**3   # target density f(x)
pdf_p = 1                     # proposal density g(x) = 1 on [0, 1)
ratio = pdf_t / pdf_p

# Candidates for the maximum on [0, 1]: stationary points of f/g inside the interval
# plus both endpoints (here d/dx ratio = 0 at x = 1/4 and x = 1).
critical_points = sympy.solve(sympy.diff(ratio, x), x)
candidates = [pt for pt in critical_points if 0 <= pt <= 1] + [sympy.Integer(0), sympy.Integer(1)]

# Pick the candidate with the largest ratio instead of hard-coding x = 1/4.
x_star = max(candidates, key=lambda pt: ratio.subs(x, pt))

# Second-order check: a negative second derivative confirms x_star is a maximum.
assert sympy.diff(ratio, x, 2).subs(x, x_star) < 0, "x_star is not a local maximum"

# ratio.subs(...) is an exact SymPy number; convert it to a Python float so the
# NumPy-based sampler below can divide by it directly.
c = float(ratio.subs(x, x_star))
c
In [15]:
# Rejection sampling
def pdf_t(x):
    """Return the target density f(x) = 20 * x * (1 - x)**3.

    Pure arithmetic, so it accepts a Python scalar or a NumPy array (element-wise).
    Note: this rebinds the name `pdf_t`, which the analytical cell used for a
    SymPy expression.
    """
    return 20 * x * (1 - x)**3
def Rejection(rng=None):
    """Draw one sample from the target density `pdf_t` by rejection sampling.

    The proposal is Uniform(0, 1) with density g(x) = 1. A proposal `prop` is
    accepted with probability pdf_t(prop) / (g(prop) * c), where `c` is the
    module-level envelope constant computed in the analytical cell.

    Parameters
    ----------
    rng : numpy.random.Generator, optional
        Source of uniform draws in [0, 1). If omitted, the legacy global
        `np.random` state is used, exactly as before. Pass
        `np.random.default_rng(seed)` for reproducible samples.

    Returns
    -------
    float
        One accepted proposal, distributed according to `pdf_t`.
    """
    uniform = np.random.rand if rng is None else rng.random
    proposal_density = 1  # g(x) for the Uniform(0, 1) proposal

    prop = uniform()
    # Each rejected round draws a fresh acceptance variate and a fresh proposal.
    while uniform() > pdf_t(prop) / proposal_density / c:
        prop = uniform()
    return prop
# Draw a single sample as a quick smoke test; the value should lie in [0, 1).
Rejection()
Out[15]:
0.7360262585475251
In [16]:
# Visual check: histogram of accepted samples against the target density.
N_SAMPLES = 1000   # number of rejection-sampled draws
N_BINS = 30        # histogram bins

samples = np.array([Rejection() for _ in range(N_SAMPLES)])
# Every accepted draw is a Uniform(0, 1) proposal, so it must lie in [0, 1).
assert ((samples >= 0) & (samples < 1)).all()

x_lins = np.linspace(0, 1, 1000)
pdf = pdf_t(x_lins)   # reuse the target density function rather than re-typing the formula

fig, ax = plt.subplots()
ax.hist(samples, bins=N_BINS, density=True, alpha=0.6, label=f"{N_SAMPLES} rejection samples")
ax.plot(x_lins, pdf, label=r"target $f(x)=20x(1-x)^3$")
ax.set(xlabel="x", ylabel="density", title="Rejection sampling with a Uniform(0, 1) proposal")
ax.legend()
plt.show()

In [0]: