## Monday, March 15, 2010

### Sampling a distribution

I was listening to an introductory talk about MCMC (link). I'll have more to say about the talk another time (hint: I think it's terrific). The speaker described a method for sampling from a probability distribution: generate random samples from U[0,1], then find the value of x at which the cumulative distribution function (CDF) first exceeds the random number. The resulting samples of x are plotted as a histogram.
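A minimal sketch of the idea, using a standard exponential as a stand-in target (my choice for illustration, not the talk's): tabulate the CDF on a grid, draw a uniform, and walk to the first grid point where the CDF exceeds it.

```python
import numpy as np

# Target: a standard exponential (illustrative), with CDF F(x) = 1 - exp(-x),
# tabulated on a grid.
dx = 0.01
xs = np.arange(0.0, 10.0, dx)
cdf = 1.0 - np.exp(-xs)

def sample_one(rng):
    # draw u ~ U[0,1], return the first x whose CDF value exceeds u
    u = rng.random()
    for i, c in enumerate(cdf):
        if u < c:
            return xs[i]
    return xs[-1]

rng = np.random.default_rng(0)
draws = [sample_one(rng) for _ in range(5000)]
# the sample mean should land near the exponential's mean of 1
```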

The Python script has three parts. The first section sets up a weighted mixture of three normal distributions with different means and standard deviations, and plots it in red using x increments of 0.01. The second section uses the same values to build a discrete CDF at the same points, plotted in blue (after normalizing by the interval size).
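As an aside, the accumulate-then-scale step can be written in one line with NumPy's `cumsum`; a sketch on a single standard normal rather than the script's mixture:

```python
import numpy as np

# Approximate a CDF from density values tabulated at spacing dx.
# This is the vectorized equivalent of accumulating the values in a
# Python loop and then multiplying by the interval size.
dx = 0.01
xs = np.arange(-10, 25 + dx, dx)
density = np.exp(-0.5 * xs**2) / np.sqrt(2 * np.pi)  # standard normal, for illustration
cdf = np.cumsum(density) * dx
# the final CDF value should be very close to 1
```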

Then we implement the sampling algorithm itself. Binning for the histogram uses the `collections.Counter` class that will be in Python 2.7. I don't have that yet, so I grabbed the class from here.
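Once 2.7 lands, the binning step can use the standard-library version directly; a sketch with made-up sample values:

```python
import math
from collections import Counter  # in the standard library as of Python 2.7

# Truncate each sample down to a multiple of the bin width (the same
# floor trick as in the script), then count bin occupancy.
width = 0.4
f = 1 / width
samples = [0.1, 0.3, 0.5, 0.55, 1.7]  # made-up values for illustration
binned = [math.floor(f * v) / f for v in samples]
c = Counter(binned)
```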

You might notice that the y-axis labels also overlap the histogram. It would be better to use two separate plots, but I'm not quite sure how to handle that yet.
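For the record, one way to get separate panels with a newer matplotlib is `plt.subplots`; the layout below is my guess at a fix, not the post's actual figure:

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # render off-screen
import matplotlib.pyplot as plt

# Put the curves and the histogram on separate axes that share an x-axis,
# so the y-axis labels of one panel cannot overlap the other.
xs = np.linspace(-6, 25, 500)
dens = np.exp(-0.5 * xs**2) / np.sqrt(2 * np.pi)  # stand-in curve

fig, (top, bottom) = plt.subplots(2, 1, sharex=True)
top.plot(xs, dens, color='r', lw=3)
bottom.hist(np.random.randn(1000), bins=40, color='green')
fig.savefig('two_panels.png')
```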

[UPDATE: more here]

```python
import math, sys
import numpy as np
import matplotlib.pyplot as plt
import Counter

def normal(mu,sigma):
    def f(x):
        z = 1.0*(x-mu)/sigma
        e = math.e**(-0.5*z**2)
        C = math.sqrt(2*math.pi)*sigma
        return 1.0*e/C
    return f

p1 = normal(0,2)
p2 = normal(10,1)
p3 = normal(18,0.5)

# sum of weighted normal distributions
@np.vectorize
def p(x):
    return 0.5*p1(x) + 0.25*p2(x) + 0.25*p3(x)

dx = 0.01
xmax = 25
R = np.arange(-10,xmax+dx,dx)

# dashed lines
plt.plot((R[0],R[-1]),(0,0),color='k',ls=':',lw=2)
plt.plot((R[0],R[-1]),(1,1),color='k',ls=':',lw=2)
plt.plot((R[0],R[-1]),(-0.5,-0.5),color='k',ls=':',lw=2)

L = p(R)
plt.plot(R,L,color='r',lw=3)
#===========================================
cdf = [L[0]]
for e in L[1:]:
    cdf.append(cdf[-1] + e)
cdf = np.array(cdf)
cdf *= dx
plt.plot(R,cdf,color='b',lw=3)

ax = plt.axes()
ax.set_xlim(-6,xmax)
ax.set_ylim(-0.55,1.05)
#===========================================
def find_first(n,L):
    for i,e in enumerate(L):
        if n < e:
            return i
    return len(L)

samples = list()
width = 0.4
f = 1/width
for i in range(1000):
    n = np.random.random()
    # must adjust to actual range
    value = find_first(n,cdf)*dx - 10.0
    # trick to truncate at fractional values
    samples.append(np.floor(f*value)/f)
#samples = np.array(samples)

c = Counter.Counter(samples)
maxn = c.most_common(1)[0][1]
for k in c:
    n = c[k]
    n = n * 0.45 / maxn
    r = plt.Rectangle((k,-0.5),width=width,height=n,facecolor='green')
    ax.add_patch(r)
plt.savefig('example.png')
```

writeonly said...

One small (potential) improvement to "find_first" would be to make it a binary
search, taking it from O(categories) to O(log categories).
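A sketch of that suggestion using the standard-library `bisect` module (my illustration, with a made-up CDF): since the CDF values are sorted, `bisect_right` returns exactly what `find_first` computes, the first index whose value is strictly greater than n.

```python
from bisect import bisect_right

cdf = [0.1, 0.25, 0.6, 0.9, 1.0]  # made-up, sorted CDF values

def find_first_binary(n, L):
    # first index i with L[i] > n, in O(log len(L)) time
    return bisect_right(L, n)

print(find_first_binary(0.5, cdf))  # → 2 (cdf[2] = 0.6 is the first value > 0.5)
```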

daedrobu87 said...

Hi. I found this code very useful. Another improvement could be to use the numpy.argwhere() function instead of enumerating the list to find the first value greater than a given value. For example, to find the values of a function that are greater than a lower bound:
---------------
import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0, 4*np.pi, np.pi/16)
cosineFunc = np.cos(t)
c = np.argwhere(cosineFunc > 0.8)
plt.plot(t, cosineFunc)
plt.plot(t[c], cosineFunc[c], 'ro')
---------------

Using built-in functions could give faster performance.
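In the same spirit, `np.searchsorted` does the "first index where the CDF exceeds u" lookup for a whole batch of uniforms at once (a sketch with made-up CDF values; `side='right'` matches the strict comparison in `find_first`):

```python
import numpy as np

cdf = np.array([0.1, 0.25, 0.6, 0.9, 1.0])  # made-up, sorted CDF values

rng = np.random.default_rng(0)
u = rng.random(4)
# for each u, the first index i with cdf[i] > u
idx = np.searchsorted(cdf, u, side='right')
```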