Wednesday, October 27, 2010

More on sampling a distribution

Some time ago I had a post about sampling from a distribution. More specifically, the problem is to generate random samples given a cdf (cumulative distribution function). There have been a couple of useful comments, and I'd like to extend the explanation as well.

Previously, I defined a function that returns the probability density at any value of a random variable x, for a probability density function (pdf) with a given mean and standard deviation. Recall that for a continuous distribution, the probability of any single exact value of x is zero, since the number of possible x's is infinite. Technically, the pdf is the derivative of the cdf (which I had misremembered as the "cumulative density function"). So cdf(x) is the area under the pdf from negative infinity to x.
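
To make the pdf/cdf relationship concrete, here is a minimal sketch (not part of the original code) that accumulates a standard normal pdf on a grid and compares the result with the closed-form normal cdf written in terms of math.erf. The names normal_pdf and normal_cdf_exact and the grid limits are my own choices for illustration.

```python
import math
import numpy as np

def normal_pdf(x, mu=0.0, sigma=1.0):
    # density of a normal distribution, works on numpy arrays
    z = (x - mu) / sigma
    return np.exp(-0.5 * z**2) / (sigma * math.sqrt(2 * math.pi))

def normal_cdf_exact(x, mu=0.0, sigma=1.0):
    # closed form of the normal cdf using the error function
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2))))

dx = 0.001
xs = np.arange(-8.0, 8.0 + dx, dx)
pdf_vals = normal_pdf(xs)

# fraction of the total area that lies to the left of each grid point
cdf_numeric = np.cumsum(pdf_vals) / pdf_vals.sum()

i = np.searchsorted(xs, 1.0)      # first grid index with xs[i] >= 1.0
approx_at_1 = cdf_numeric[i]
exact_at_1 = normal_cdf_exact(1.0)
print(approx_at_1, exact_at_1)
```

The two printed numbers should agree to about three decimal places; the small gap comes from the finite grid spacing.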

One nice thing about this approach is it is then easy to define a sum of weighted distributions.
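
As a rough sketch of what "a sum of weighted distributions" can look like with this closure style, the snippet below builds a mixture from a list of weights and pdfs. The mixture helper is hypothetical (it is not in the code at the end of this post, which writes the weighted sum out by hand), and it assumes the weights are meant to sum to 1.

```python
import math

def normal(mu, sigma):
    """Return the pdf of a normal distribution as a function of x."""
    def f(x):
        z = (x - mu) / sigma
        return math.exp(-0.5 * z**2) / (sigma * math.sqrt(2 * math.pi))
    return f

def mixture(weights, pdfs):
    """Hypothetical helper: weighted sum of pdfs (weights must sum to 1)."""
    if abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("weights must sum to 1")
    def f(x):
        return sum(w * p(x) for w, p in zip(weights, pdfs))
    return f

# same weights and components as the example used in this post
p = mixture([0.5, 0.25, 0.25],
            [normal(0, 2), normal(10, 1), normal(18, 0.5)])
print(p(0.0), p(10.0), p(18.0))
```

Because each component integrates to 1 and the weights sum to 1, the mixture is itself a valid pdf.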

We get what looks like a smooth curve by plotting a (relatively) large number of points (in this example, 3502). The plotting uses matplotlib (see this post referencing set-up on the Mac). The cdf is computed by simply accumulating values from the pdf. Normalization is usually done by dividing by the total, but the method I showed was just slightly more subtle:

cdf *= dx

Unfortunately, it is also wrong! (Sorry). I won't try to explain what I was thinking, but the fact that it gave an accurate value (as shown by the maximum of the cdf being equal to 1) is an accident, and you should instead simply divide by the sum of the values, and moreover, do the operation on the pdf before constructing the cdf:

pdf /= sum(pdf)
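
One way to see the difference between the two normalizations is to try them on a grid that covers only part of a distribution. The sketch below is an illustration under my own assumptions (a normal pdf with mean 0 and standard deviation 2, evaluated only for x >= 0); it is not the code from the earlier post.

```python
import numpy as np

def normal_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return np.exp(-0.5 * z**2) / (sigma * np.sqrt(2 * np.pi))

dx = 0.01
# deliberately truncated grid: only the right half of the bell curve
xs = np.arange(0.0, 10.0 + dx, dx)
pdf = normal_pdf(xs, 0.0, 2.0)

cdf_dx = np.cumsum(pdf) * dx            # scale the running total by the step size
cdf_sum = np.cumsum(pdf / pdf.sum())    # normalize the pdf first, then accumulate

last_dx = cdf_dx[-1]
last_sum = cdf_sum[-1]
print(last_dx, last_sum)
```

Scaling by dx ends near 0.5 here because the grid only sees half of the area, while dividing by the sum always ends at 1, which is what the sampling step needs.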


This change to the pdf means we need to magnify it before plotting, and really we should provide a second y-axis on the right-hand side showing the true values. The left-hand y-axis is accurate only for the cdf.

The idea for sampling is to generate a random float using np.random.random(), and then ask which of the values in the cdf (which by definition are ordered), first exceeds this value. The indexes resulting from repeating this procedure are concentrated in the steep part of the cdf (the peaks of the pdf), because the probability that a given position in our "discretized" form of the cdf satisfies this relationship is proportional to the slope of the cdf curve (the added vertical distance between a given index and the one previous).
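
Here is a small sketch of that idea on a toy five-point pdf, so the result is easy to check by eye. It uses np.searchsorted rather than an explicit loop, and a seeded np.random.default_rng instead of np.random.random() purely so the run is repeatable; the sample_indexes name and the toy values are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)   # seeded only to make the sketch repeatable

# toy discretized pdf over five positions, already normalized
pdf = np.array([0.05, 0.10, 0.60, 0.20, 0.05])
cdf = np.cumsum(pdf)

def sample_indexes(cdf, n, rng):
    u = rng.random(n)                            # uniform floats in [0, 1)
    # side='right' gives the index of the first cdf value that exceeds u
    idx = np.searchsorted(cdf, u, side='right')
    return np.minimum(idx, len(cdf) - 1)         # guard against rounding in cdf[-1]

idx = sample_indexes(cdf, 100000, rng)
freq = np.bincount(idx, minlength=len(cdf)) / len(idx)
print(freq)
```

The observed frequencies should land close to the pdf values, with most samples falling at index 2, where the cdf takes its largest step.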

As a reader suggested, the code can be improved by recognizing that the list we're searching (the cdf) is ordered, and so can be searched more efficiently with a binary search. That code is a little tricky to write, so I skipped it last time. Luckily, Python comes with "batteries included", and for this application what we want is the bisect module from the standard library. The find_le example in the bisect documentation has the docstring: 'Find rightmost value less than or equal to x'. I just modified this code (which calls bisect.bisect_right) to return the index rather than the value.

We're interested in intervals where the slope is steepest. The original find_first function returned the index of the right-hand value, while this function will return the index of the left-hand value. I suppose either one is fine, but perhaps it would be better to use the midpoint.
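
The left-hand, right-hand and midpoint choices can be compared on a tiny example. This is only a sketch: the four-element cdf, the matching xs grid and the bracket helper are made up for illustration.

```python
import bisect

xs = [0.0, 1.0, 2.0, 3.0]      # hypothetical x grid
cdf = [0.1, 0.3, 0.7, 1.0]     # hypothetical cdf values on that grid

def bracket(cdf, n):
    """Return (left, right) indexes with cdf[left] <= n < cdf[right]."""
    right = bisect.bisect_right(cdf, n)   # first index whose value exceeds n
    left = right - 1                      # rightmost index whose value is <= n
    if left < 0 or right >= len(cdf):
        raise ValueError("n lies outside the range covered by the cdf")
    return left, right

left, right = bracket(cdf, 0.5)
mid_x = 0.5 * (xs[left] + xs[right])      # midpoint of the bracketing interval
print(left, right, mid_x)
```

With a step as small as the dx used in this post, the three choices differ by at most one grid step, so the histogram looks essentially the same either way.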

There are a few more steps in the code that are a little obscure, including bins with fractional width, and the use of the Counter class to organize the data for the histogram. But this post is getting a bit long so I'll skip them for now.
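
For the curious, the two obscure steps boil down to something like the following sketch. It assumes collections.Counter from the standard library, and the bin_left_edge helper and the sample values are made up for illustration.

```python
import math
from collections import Counter

width = 0.2
f = 1 / width      # number of bins per unit of x

def bin_left_edge(value):
    # scale up, drop to the integer below, then scale back down,
    # which snaps the value to the left edge of its bin
    return math.floor(f * value) / f

values = [0.05, 0.19, 0.21, 0.39, 0.45, -0.05]
counts = Counter(bin_left_edge(v) for v in values)

# height of the tallest bar, used to scale all the other bars
maxn = counts.most_common(1)[0][1]
print(sorted(counts.items()), maxn)
```

Each key in the Counter is the left edge of a bin, so it can be passed straight to plt.Rectangle as the x position of the bar.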



Modified code:


import math, bisect
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt

def normal(mu,sigma):
    # returns the pdf of a normal distribution as a function of x
    def f(x):
        z = 1.0*(x-mu)/sigma
        e = math.e**(-0.5*z**2)
        C = math.sqrt(2*math.pi)*sigma
        return 1.0*e/C
    return f

p1 = normal(0,2)
p2 = normal(10,1)
p3 = normal(18,0.5)

# sum of weighted normal distributions
@np.vectorize
def p(x):
    return 0.5*p1(x) + 0.25*p2(x) + 0.25*p3(x)

dx = 0.01
xmax = 25
R = np.arange(-10,xmax+dx,dx)
# dotted guide lines at 0 and 1, solid baseline for the histogram
plt.plot((R[0],R[-1]),(0,0),color='k',ls=':',lw=2)
plt.plot((R[0],R[-1]),(1,1),color='k',ls=':',lw=2)
plt.plot((R[0],R[-1]),(-0.5,-0.5),color='k',lw=2)

pdf = p(R)
# normalize the pdf so that its values sum to 1
pdf /= pdf.sum()
# magnify the pdf so it is visible on the same axes as the cdf
plt.plot(R,pdf*(len(R)/10.0),color='r',lw=3)
#print(len(R))
print(pdf.sum())
#===========================================
# the cdf is the running total of the normalized pdf
cdf = np.cumsum(pdf)
plt.plot(R,cdf,color='b',lw=3)

# use the axes we have already drawn on
ax = plt.gca()
ax.set_xlim(-6,xmax)
ax.set_ylim(-0.55,1.05)
#===========================================
def find_le(n,L):
    'modified from find_le in the bisect docs: index of rightmost value <= n'
    i = bisect.bisect_right(L,n)
    if i:
        # i is the first index whose value exceeds n, so step back one
        return i-1
    raise ValueError

samples = list()
width = 0.2
f = 1/width
for i in range(10000):
    n = np.random.random()
    # must adjust to actual range (R starts at -10)
    value = find_le(n,cdf)*dx - 10.0
    # trick to truncate at fractional values
    samples.append(np.floor(f*value)/f)
#===========================================
# count the samples in each bin, keyed by the bin's left edge
c = Counter(samples)
maxn = c.most_common(1)[0][1]
for k in c:
    n = c[k]
    # scale so the tallest bar has height 0.45
    n = n * 0.45 / maxn
    r = plt.Rectangle((k,-0.5),
        width=width,height=n,
        facecolor='magenta')
    ax.add_patch(r)
plt.savefig('example.png')

1 comment:

Conrad Lee said...

Are you saying that the code in your previous code is incorrect? If so, you should probably update the incorrect post to say that it's wrong. As it stands now, it's possible that most people visiting the incorrect code will never know it's incorrect.