Monday, March 15, 2010

Sampling a distribution


I was listening to an introductory talk about MCMC (link). I'll have more to say about the talk another time (hint: I think it's terrific). The speaker described a method for sampling from a probability distribution, known as inverse transform sampling. The idea is to generate random numbers from U[0,1] and, for each one, find the value of x at which the cumulative distribution function (cdf) first exceeds that number. The resulting samples of x are plotted as a histogram.
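A minimal sketch of that sampling step in plain Python, assuming the distribution is given as a grid of x values `xs` with matching cdf values `cdf` (both names are illustrative, not from the talk). It uses a linear scan, as described above:

```python
import random

def sample_from_cdf(xs, cdf, rng=random):
    """Draw one sample: the first x whose cdf value exceeds a uniform random number."""
    u = rng.random()              # u ~ U[0, 1)
    for x, F in zip(xs, cdf):
        if u < F:                 # first grid point where the cdf passes u
            return x
    return xs[-1]                 # guard against round-off when cdf[-1] < 1

# usage: a three-point distribution with probabilities 0.2, 0.5, 0.3
xs = [1, 2, 3]
cdf = [0.2, 0.7, 1.0]
draws = [sample_from_cdf(xs, cdf, random.Random(0)) for _ in range(5)]
```

Values of u below 0.2 map to 1, values in [0.2, 0.7) map to 2, and the rest map to 3, so each x is drawn with the probability given by the jump in the cdf at that point.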

The Python script has three parts. In the first section we set up a weighted mixture of three normal distributions with different means and standard deviations, and plot the density in red using x increments of 0.01. In the second section the same density values are summed cumulatively to give a discrete cdf at the same points. This is plotted in blue (after multiplying by the interval size dx, so that the curve rises to 1).
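To check that the running sum behaves like a cdf, here is a minimal pure-Python sketch (the names `mixture_pdf` and `mixture_cdf` are hypothetical, chosen for illustration). It rebuilds the same mixture, accumulates pdf × dx on the same grid, and compares the result with the exact mixture cdf, which is the same weighted sum of normal cdfs written with `math.erf`:

```python
import math

def normal_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (math.sqrt(2 * math.pi) * sigma)

def normal_cdf(x, mu, sigma):
    return 0.5 * (1 + math.erf((x - mu) / (sigma * math.sqrt(2))))

# same weights, means and standard deviations as in the script
COMPONENTS = [(0.5, 0, 2), (0.25, 10, 1), (0.25, 18, 0.5)]

def mixture_pdf(x):
    return sum(w * normal_pdf(x, mu, s) for w, mu, s in COMPONENTS)

def mixture_cdf(x):
    return sum(w * normal_cdf(x, mu, s) for w, mu, s in COMPONENTS)

dx = 0.01
xs = [-10 + i * dx for i in range(3501)]   # -10 .. 25, like np.arange(-10, 25 + dx, dx)

# discrete cdf: running sum of pdf values, each scaled by the step size dx
discrete = []
running = 0.0
for x in xs:
    running += mixture_pdf(x) * dx
    discrete.append(running)
```

Because each step adds p(x)·dx, the discrete cdf runs slightly ahead of the exact one (by roughly p(x)·dx/2, well under 0.01 here), and it ends very close to 1.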

In the third section we implement the sampling algorithm. Binning for the histogram uses the collections.Counter class, which will be in Python 2.7. I don't have that version, so I grabbed the class from here.
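The binning step can be sketched on its own. This is a minimal version assuming the standard-library `collections.Counter` (Python 2.7 and later); the helper names `bin_counts` and `bar_heights` are hypothetical. Truncating each sample down to a multiple of the bin width gives the left edge of its bin, which becomes the Counter key:

```python
import math
from collections import Counter

def bin_counts(samples, width):
    """Count samples per bin; each key is the left edge of the sample's bin."""
    # floor(s / width) * width truncates s down to a multiple of width; this is
    # equivalent (up to floating-point rounding) to np.floor(f*value)/f with f = 1/width
    return Counter(math.floor(s / width) * width for s in samples)

def bar_heights(counts, tallest=0.45):
    """Scale counts so the most common bin is drawn `tallest` units high."""
    maxn = counts.most_common(1)[0][1]   # count of the most common bin
    return {k: n * tallest / maxn for k, n in counts.items()}

counts = bin_counts([0.1, 0.3, 0.5, 0.9, -0.1], 0.4)
heights = bar_heights(counts)
```

Here 0.1 and 0.3 both fall in the bin starting at 0.0, so that bin gets the full height 0.45 and the single-sample bins get half of it.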

You might notice that the y-axis labels also apply to the histogram, even though its bar heights are scaled arbitrarily (the tallest bar is drawn 0.45 units high). It would be better to have two separate plots, but I'm not quite sure how to handle that yet.
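One possible way to get two separate plots is `plt.subplots` with a shared x-axis: the curves go in the top panel and the histogram in the bottom one, each with its own y-axis. This is a minimal sketch, assuming the raw samples are passed in (so `Axes.hist` does the binning instead of Counter); the function name `two_panel_figure` is hypothetical:

```python
import matplotlib
matplotlib.use("Agg")                 # render to a file; no display needed
import matplotlib.pyplot as plt
import numpy as np

def two_panel_figure(x, pdf, cdf, samples, bins=50):
    """Curves on top, histogram below; both panels share the same x-axis."""
    fig, (top, bottom) = plt.subplots(2, 1, sharex=True, figsize=(8, 6))
    top.plot(x, pdf, color="r", lw=3, label="pdf")
    top.plot(x, cdf, color="b", lw=3, label="cdf")
    top.set_ylim(-0.05, 1.05)
    top.legend(loc="upper left")
    bottom.hist(samples, bins=bins, color="green")
    bottom.set_ylabel("count")        # real counts, no arbitrary rescaling
    return fig

# usage with toy data
x = np.linspace(0, 1, 11)
fig = two_panel_figure(x, x, x, [0.1, 0.2, 0.3], bins=3)
fig.savefig("two_panels.png")
```

Since the histogram has its own axes, the bars no longer need to be squeezed below y = 0 and can show actual counts.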

[UPDATE: more here]


import math
import numpy as np
import matplotlib.pyplot as plt
try:
    from collections import Counter   # Python 2.7 and later
except ImportError:
    from Counter import Counter       # standalone copy of the class for older Pythons

def normal(mu, sigma):
    """Return the normal pdf with mean mu and standard deviation sigma."""
    def f(x):
        z = (x - mu) / float(sigma)
        return math.exp(-0.5 * z**2) / (math.sqrt(2 * math.pi) * sigma)
    return f

p1 = normal(0, 2)
p2 = normal(10, 1)
p3 = normal(18, 0.5)

# sum of weighted normal distributions (weights add up to 1)
@np.vectorize
def p(x):
    return 0.5*p1(x) + 0.25*p2(x) + 0.25*p3(x)

dx = 0.01
xmax = 25
R = np.arange(-10, xmax + dx, dx)
# dotted reference lines at y = 0, y = 1 and y = -0.5 (the histogram baseline)
for y in (0, 1, -0.5):
    plt.plot((R[0], R[-1]), (y, y), color='k', ls=':', lw=2)
L = p(R)
plt.plot(R, L, color='r', lw=3)
#===========================================
# discrete cdf: running sum of the pdf values, scaled by the step size dx
cdf = np.cumsum(L) * dx
plt.plot(R, cdf, color='b', lw=3)

# current axes (in recent matplotlib, plt.axes() would create new, empty axes)
ax = plt.gca()
ax.set_xlim(-6, xmax)
ax.set_ylim(-0.55, 1.05)
#===========================================
def find_first(n, L):
    """Index of the first element of L greater than n (len(L) if there is none)."""
    for i, e in enumerate(L):
        if n < e:
            return i
    return len(L)

samples = list()
width = 0.4
f = 1 / width
for i in range(1000):
    n = np.random.random()
    # convert the index back to an x value: the grid starts at R[0] = -10
    value = find_first(n, cdf) * dx + R[0]
    # trick to truncate to the left edge of a bin of the given width
    samples.append(np.floor(f * value) / f)

c = Counter(samples)
maxn = c.most_common(1)[0][1]
for k in c:
    # scale so the tallest bar is 0.45 high, drawn upward from y = -0.5
    n = c[k] * 0.45 / maxn
    r = plt.Rectangle((k, -0.5), width=width, height=n, facecolor='green')
    ax.add_patch(r)
plt.savefig('example.png')

2 comments:

writeonly said...

One small (potential) improvement to "find_first" would be to make it a binary search, taking it from O(categories) to O(log categories).
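A minimal sketch of that suggestion using the standard-library `bisect` module (the name `find_first_bisect` is hypothetical). Binary search works here because a cdf never decreases, so the list is sorted:

```python
import bisect

def find_first(n, L):
    """Linear scan from the post: index of the first element of L greater than n."""
    for i, e in enumerate(L):
        if n < e:
            return i
    return len(L)

def find_first_bisect(n, L):
    """Same result in O(log len(L)) steps; L must be sorted in non-decreasing order."""
    # bisect_right returns the insertion point to the right of any entries equal to n,
    # which is the index of the first element strictly greater than n (len(L) if none)
    return bisect.bisect_right(L, n)

cdf = [0.2, 0.7, 0.7, 1.0]
idx = find_first_bisect(0.7, cdf)   # first value strictly greater than 0.7
```

Using `bisect_right` rather than `bisect_left` matters: it keeps the strict `n < e` test of the original, including ties where the random number equals a cdf value exactly.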

daedrobu87 said...

Hi. I found this code very useful. Another improvement could be to use the numpy.where() or numpy.argwhere() functions instead of using enumerate to find the first value greater than a given value. For example, to find the values of a function that are greater than a lower bound:
---------------
import numpy as np
import matplotlib.pyplot as plt
t = np.arange(0, 4*np.pi, np.pi/16)
cosineFunc = np.cos(t)
c = np.argwhere(cosineFunc > 0.8)
plt.plot(t, cosineFunc)
plt.plot(t[c], cosineFunc[c], 'ro')
---------------

Using built-in functions could give faster performance.
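Along the same lines, NumPy's `np.searchsorted` (a different built-in from the `where`/`argwhere` functions above) does a binary search on a sorted array and accepts a whole array of random numbers at once, so the sampling loop can be vectorized. A minimal sketch, assuming the grid `x` and the discrete `cdf` are NumPy arrays of equal length; `sample_grid` is a hypothetical name:

```python
import numpy as np

def sample_grid(x, cdf, size, rng=None):
    """Draw `size` samples: for each uniform u, the first x whose cdf exceeds u."""
    rng = np.random.default_rng() if rng is None else rng
    u = rng.random(size)                          # size draws from U[0, 1)
    # side="right" gives the first index with cdf[i] > u, like find_first
    idx = np.searchsorted(cdf, u, side="right")
    idx = np.minimum(idx, len(x) - 1)             # guard when u >= cdf[-1]
    return x[idx]

# usage: three-point distribution with probabilities 0.2, 0.5, 0.3
x = np.array([1.0, 2.0, 3.0])
cdf = np.array([0.2, 0.7, 1.0])
s = sample_grid(x, cdf, 20000, np.random.default_rng(0))
```

With the script's arrays, `sample_grid(R, cdf, 1000)` would replace the whole Python-level loop over `find_first`, leaving only the binning step.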