6. Dirichlet and Gaussian Mixture Models

This notebook shows how to learn the distribution of data and then simulate samples from that learned distribution.

6.1. Motivating example

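Suppose we observe a random variable \(X\). When we plot the distribution of \(X\), we see something like the following. Here, we know that the observations of \(X\) actually came from a mixture of three Gaussians (but pretend we do not know). In the list below, the second parameter of each Gaussian is its standard deviation, matching the scale argument of np.random.normal.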

  • \(X_0 \sim \mathcal{N}(0, 1)\)

  • \(X_1 \sim \mathcal{N}(5, 1)\)

  • \(X_2 \sim \mathcal{N}(10, 2)\)
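The code below draws 500, 1,000 and 300 of the 1,800 points from these three Gaussians. If we treat those sample fractions as mixture weights, the pooled data roughly follow the single density below. Note that in this formula the second argument of each \(\mathcal{N}\) is the variance. These weights are a useful reference point for the GMM weights learned later.

```latex
p(x) = \frac{500}{1800}\,\mathcal{N}(x \mid 0, 1^2)
     + \frac{1000}{1800}\,\mathcal{N}(x \mid 5, 1^2)
     + \frac{300}{1800}\,\mathcal{N}(x \mid 10, 2^2)
     \approx 0.278\,\mathcal{N}(x \mid 0, 1^2)
     + 0.556\,\mathcal{N}(x \mid 5, 1^2)
     + 0.167\,\mathcal{N}(x \mid 10, 2^2)
```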

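%matplotlib inline
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Draw samples from three Gaussians (loc = mean, scale = standard deviation)
x0 = np.random.normal(0, 1, 500)
x1 = np.random.normal(5, 1, 1000)
x2 = np.random.normal(10, 2, 300)
# Pool them into one column: the (n_samples, n_features) shape scikit-learn expects
X = np.concatenate([x0, x1, x2]).reshape(-1, 1)

# Histogram (scaled as a density) with a kernel density estimate overlaid
_ = sns.histplot(X.ravel(), kde=True, stat='density')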

6.2. k-means clustering (KMC)

We apply k-means clustering to \(X\) to see where the clusters are for k = 2, 3, 4 and 5. For each k, we evaluate the clustering with the silhouette score and choose the k with the largest score. In this case, k = 3 has the highest silhouette score.
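For reference, here is how the silhouette score is defined. Take a point, let \(a\) be its mean distance to the other points in its own cluster, and let \(b\) be its smallest mean distance to the points of any other cluster. The point's silhouette coefficient is then \((b - a)/\max(a, b)\), and the silhouette score is the mean over all points.

The NumPy-only sketch below computes this for 1-D data, under a few assumptions:

- silhouette_1d is a hypothetical name.
- It builds the full pairwise distance matrix, which is fine for a few thousand points.
- It needs at least two clusters.
- It assigns 0 to points in singleton clusters, an assumption meant to mirror how scikit-learn's silhouette_score treats them.

```python
import numpy as np

def silhouette_1d(x, labels):
    """Mean silhouette coefficient for 1-D data (hypothetical helper)."""
    x = np.ravel(np.asarray(x, dtype=float))
    labels = np.asarray(labels)
    dist = np.abs(x[:, None] - x[None, :])  # pairwise |x_i - x_j|
    clusters = np.unique(labels)
    scores = np.zeros(len(x))
    for i in range(len(x)):
        same = labels == labels[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # singleton cluster: score stays 0
        # a: mean distance to the other members of i's cluster (self-distance is 0)
        a = dist[i, same].sum() / (n_same - 1)
        # b: smallest mean distance to the members of any other cluster
        b = min(dist[i, labels == c].mean() for c in clusters if c != labels[i])
        denom = max(a, b)
        scores[i] = 0.0 if denom == 0 else (b - a) / denom
    return scores.mean()

# Two tight, well-separated clusters give a score close to 1.
print(round(silhouette_1d([0.0, 1.0, 10.0, 11.0], [0, 0, 1, 1]), 4))  # 0.8997
```

On the notebook's data, this should agree with silhouette_score(X, labels) for the same labels, at the cost of O(n²) memory.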

from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def get_silhouette_score(X, k):
    print('starting k={}'.format(k))
    model = KMeans(n_clusters=k, n_init=10, random_state=37)
    # fit the model and return a cluster label for each row of X
    labels = model.fit_predict(X)
    score = silhouette_score(X, labels)
    return score

kms_scores = sorted([(k, get_silhouette_score(X, k)) for k in range(2, 6)], key=lambda tup: (-tup[1], tup[0]))
starting k=2
starting k=3
starting k=4
starting k=5
[(3, 0.7068431322753701),
 (2, 0.6343853947741563),
 (4, 0.6067295633567155),
 (5, 0.5721530284698704)]

6.3. Gaussian Mixture Model (GMM)

KMC gave us evidence that \(X\) has 3 modes (sub-populations). Let’s use k = 3 to learn a Gaussian mixture model (GMM).

from sklearn.mixture import GaussianMixture

def get_gmm_labels(X, k):
    gmm = GaussianMixture(n_components=k, max_iter=50, covariance_type='spherical', random_state=37)
    # fit the mixture, then assign each point to its most likely component
    labels = gmm.fit_predict(X)
    return labels, gmm

labels, gmm = get_gmm_labels(X, kms_scores[0][0])

Since we used k=3 from KMC to define the number of components for GMM, we should expect the number of components to be 3.
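GaussianMixture fits its parameters with expectation-maximization (EM). Each iteration has two steps. The E-step computes each point's responsibilities, the posterior probability of each component given the point. The M-step then re-estimates the weights, means and standard deviations from those soft assignments.

Below is a minimal 1-D sketch of EM. It is not scikit-learn's implementation, and it makes these assumptions:

- em_gmm_1d is a hypothetical name.
- The means start at evenly spaced quantiles of the data, and the standard deviations start at a rough heuristic guess. By contrast, scikit-learn initializes with k-means by default.
- There is no convergence check, no regularization, and no protection against a component collapsing.

```python
import numpy as np

def em_gmm_1d(x, k, n_iter=200):
    """Minimal EM for a 1-D Gaussian mixture (hypothetical helper, not scikit-learn's code).

    Returns arrays of component means, standard deviations and weights.
    """
    x = np.ravel(np.asarray(x, dtype=float))
    # Heuristic initialization: means at evenly spaced quantiles, one shared rough std.
    means = np.quantile(x, (np.arange(k) + 0.5) / k)
    stds = np.full(k, x.std() / k)
    weights = np.full(k, 1.0 / k)
    for _ in range(n_iter):
        # E-step: r[i, j] = P(component j | x_i) under the current parameters.
        z = (x[:, None] - means) / stds
        dens = np.exp(-0.5 * z ** 2) / (stds * np.sqrt(2.0 * np.pi))
        weighted = dens * weights
        r = weighted / weighted.sum(axis=1, keepdims=True)
        # M-step: re-estimate the parameters from the soft assignments.
        nk = r.sum(axis=0)  # effective number of points per component
        weights = nk / len(x)
        means = (r * x[:, None]).sum(axis=0) / nk
        stds = np.sqrt((r * (x[:, None] - means) ** 2).sum(axis=0) / nk)
    return means, stds, weights

# Toy data: three well-separated components with equal weights.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0, 1, 1000), rng.normal(5, 1, 1000), rng.normal(10, 1, 1000)])
means, stds, weights = em_gmm_1d(x, k=3)
print(np.round(means, 2), np.round(stds, 2), np.round(weights, 3))
```

With a poor initialization, or with overlapping components, EM can stop at a local optimum. This is one reason scikit-learn offers k-means initialization and multiple restarts (n_init).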


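Here are the means, standard deviations (square roots of the spherical covariances) and weights learned by the GMM.

print(gmm.means_.reshape(1, -1)[0])
[5.0550979  0.01368472 9.96071725]
print(np.sqrt(gmm.covariances_))
[1.03395708 1.01027696 1.86489054]
print(gmm.weights_)
[0.56339635 0.27806376 0.15853989]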
gaussians = {}
for label, mu, std in zip(range(gmm.n_components), gmm.means_.reshape(1, -1)[0], np.sqrt(gmm.covariances_)):
    print('{} : {}, {}'.format(label, mu, std))
    gaussians[label] = (mu, std)
0 : 5.055097898869679, 1.033957076311051
1 : 0.013684724179379982, 1.0102769579001163
2 : 9.960717253224319, 1.864890539118086

6.4. Simulation using Dirichlet

Now we can simulate samples by using the GMM's mixture weights as the concentration parameters of a Dirichlet distribution. Each Dirichlet draw is a vector of probabilities over the three components, and the code below keeps the component with the highest probability in each draw as the sampled label. After sampling a label, we draw a value from the Gaussian corresponding to that label. Together these values form \(S\), the simulated data.

Note that we could also simulate data directly from the GMM with gmm.sample(10), but where’s the fun in that?
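For comparison, sampling directly from a mixture is a two-step (ancestral) procedure:

1. Draw a component label with probability equal to that component's weight.
2. Draw a value from that component's Gaussian.

gmm.sample follows this idea and returns a (samples, labels) pair. The NumPy sketch below is an illustration, not scikit-learn's code. sample_mixture is a hypothetical helper, and the parameters are copied from the GMM printout above.

```python
import numpy as np

def sample_mixture(weights, means, stds, n, seed=0):
    """Ancestral sampling from a 1-D Gaussian mixture (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()  # guard against rounding in the printed weights
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    # Step 1: label j is drawn with probability weights[j].
    labels = rng.choice(len(weights), size=n, p=weights)
    # Step 2: one draw from N(means[j], stds[j]**2) per label.
    samples = rng.normal(means[labels], stds[labels])
    return samples, labels

# Parameters copied from the GMM printout above.
weights = [0.56339635, 0.27806376, 0.15853989]
means = [5.0550979, 0.01368472, 9.96071725]
stds = [1.03395708, 1.01027696, 1.86489054]
samples, labels = sample_mixture(weights, means, stds, n=1800)
print(samples[:5].round(2), np.bincount(labels))
```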

from scipy.stats import dirichlet

mu_indices = np.argmax(dirichlet.rvs(gmm.weights_, size=X.shape[0], random_state=37), axis=1)
S = np.array([np.random.normal(gaussians[idx][0], gaussians[idx][1]) for idx in mu_indices]).reshape(-1, 1)
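Keeping the argmax of a Dirichlet draw is not the same as drawing a label with probability equal to its weight. In general, the probability that component j has the largest entry differs from w_j. For example, with two components and concentration parameters (2, 1), the first entry is Beta(2, 1) distributed. It is the largest with probability 0.75 rather than 2/3.

The sketch below compares label frequencies from the two schemes, using NumPy's Generator.dirichlet and Generator.choice. Both helper names are hypothetical. Any gap between the printed rows is the approximation error of the argmax trick for these weights.

```python
import numpy as np

def dirichlet_argmax_labels(weights, n, seed=0):
    """Hypothetical helper: one Dirichlet draw per sample, keep the index of its largest entry."""
    rng = np.random.default_rng(seed)
    draws = rng.dirichlet(weights, size=n)  # shape (n, k); each row sums to 1
    return np.argmax(draws, axis=1)

def categorical_labels(weights, n, seed=0):
    """Hypothetical helper: draw labels with P(label = j) = weights[j]."""
    rng = np.random.default_rng(seed)
    return rng.choice(len(weights), size=n, p=weights)

weights = np.array([0.56339635, 0.27806376, 0.15853989])  # GMM weights printed above
weights = weights / weights.sum()
n = 100_000
for name, labels in [('dirichlet argmax', dirichlet_argmax_labels(weights, n)),
                     ('categorical', categorical_labels(weights, n))]:
    freqs = np.bincount(labels, minlength=len(weights)) / n
    print(name, freqs.round(3))
print('weights', weights.round(3))
```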

Plotting the distribution of \(S\) shows that it closely resembles the distribution of \(X\).

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax = np.ravel(ax)

# histograms (as densities) with KDE overlays, then both KDEs on shared axes
sns.histplot(S.ravel(), kde=True, stat='density', ax=ax[0])
sns.histplot(X.ravel(), kde=True, stat='density', ax=ax[1])
sns.kdeplot(S.ravel(), ax=ax[2], color='red', label='S')
sns.kdeplot(X.ravel(), ax=ax[2], color='blue', label='X')

ax[0].set_title('Distribution of S')
ax[1].set_title('Distribution of X')
ax[2].set_xlabel('S and X')
ax[2].set_ylabel('P(S) or P(X)')
ax[2].set_title('Distributions of S and X')
_ = ax[2].legend()


6.5. Testing the closeness of the simulated and empirical data

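We can quantify the closeness of the simulated and empirical data with the Pearson correlation of their sorted values. Sorting pairs the i-th smallest value of \(S\) with the i-th smallest value of \(X\), so the correlation measures how close the Q-Q plot below is to a straight line. This pairing works because \(S\) and \(X\) have the same number of points.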
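To see why sorting matters, the sketch below draws two independent samples from the mixture in section 6.1. As an assumption for illustration, it uses the sample fractions as weights. It then computes the Pearson correlation with np.corrcoef both ways:

- Unsorted, the pairs are unrelated draws and the correlation is near 0.
- Sorted, each pair is a matching quantile and the correlation is near 1.

draw is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = np.array([500, 1000, 300]) / 1800  # sample fractions from section 6.1
mus = np.array([0.0, 5.0, 10.0])
sigmas = np.array([1.0, 1.0, 2.0])

def draw(n):
    """Draw n values from the section 6.1 mixture (hypothetical helper)."""
    labels = rng.choice(3, size=n, p=weights)
    return rng.normal(mus[labels], sigmas[labels])

a, b = draw(2000), draw(2000)
r_unsorted = np.corrcoef(a, b)[0, 1]                  # unrelated pairs: close to 0
r_sorted = np.corrcoef(np.sort(a), np.sort(b))[0, 1]  # matching quantiles: close to 1
print(round(r_unsorted, 3), round(r_sorted, 3))
```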

from scipy.stats import pearsonr

pearsonr(np.sort(S.reshape(1, -1)[0]), np.sort(X.reshape(1, -1)[0]))
(0.9986147691814674, 0.0)

We can also draw a quantile-quantile (Q-Q) plot to see how well the two data sets align.
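Pairing sorted values requires both samples to have the same length. When they do not, one option is to compare a fixed set of quantiles computed with np.quantile, as in the sketch below. qq_points is a hypothetical helper, and the choice of 100 probability levels is arbitrary. The resulting pairs can be plotted, or passed to pearsonr, in the same way as the sorted data.

```python
import numpy as np

def qq_points(a, b, n_quantiles=100):
    """Matching quantiles of two samples, which may differ in length (hypothetical helper)."""
    probs = (np.arange(n_quantiles) + 0.5) / n_quantiles  # levels strictly between 0 and 1
    qa = np.quantile(np.ravel(a), probs)
    qb = np.quantile(np.ravel(b), probs)
    return qa, qb

# Two evenly spaced grids on [0, 100] with different lengths have identical quantiles.
qa, qb = qq_points(np.arange(101.0), np.arange(201.0) / 2)
print(np.allclose(qa, qb))  # True
```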

fig, ax = plt.subplots(figsize=(8, 8))
s_sorted = np.sort(S.ravel())
x_sorted = np.sort(X.ravel())
# sorted values of equal-length samples are matching quantiles
ax.plot(s_sorted, x_sorted, lw=2, alpha=0.7)
ax.set_title('Empirical (X) vs Simulated (S)')
ax.set_xlabel('S quantiles')
ax.set_ylabel('X quantiles')
# y = x reference line over the shared range: points on it have matching quantiles
lo = min(s_sorted[0], x_sorted[0])
hi = max(s_sorted[-1], x_sorted[-1])
_ = ax.plot([lo, hi], [lo, hi], c='red', ls='--', alpha=0.5)

6.6. Computing the probability of a data point given the GMM

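You may also compute the probability density of a data point under the GMM as

\(P(x \mid G) = p^\top w = \sum_{k=1}^{K} w_k \, \mathcal{N}(x \mid \mu_k, \sigma_k^2)\),

where

  • \(x\) is the data point,

  • \(G\) is the GMM,

  • \(p\) is the vector of densities of \(x\) under each of the \(K\) Gaussians in the GMM (here \(K = 3\)), with \(p_k = \mathcal{N}(x \mid \mu_k, \sigma_k^2)\), and

  • \(w\) is the vector of the mixture weights of those Gaussians.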
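The formula can also be evaluated for many points at once by broadcasting. First build an (n, K) matrix whose row i is the vector \(p\) for point \(x_i\), then multiply it by \(w\).

The sketch below does this with NumPy only. It writes out the normal density by hand instead of calling scipy.stats.norm.pdf, and it uses the parameters printed earlier as example inputs. gmm_density is a hypothetical name.

```python
import numpy as np

def gmm_density(x, weights, means, stds):
    """P(x | G) = sum_k w_k N(x | mu_k, sigma_k^2) for every x at once (hypothetical helper)."""
    x = np.ravel(np.asarray(x, dtype=float))[:, None]  # shape (n, 1)
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    # Row i is the vector p for x_i: one normal density per component, shape (n, K).
    p = np.exp(-0.5 * ((x - means) / stds) ** 2) / (stds * np.sqrt(2.0 * np.pi))
    return p @ np.asarray(weights, dtype=float)  # p'w for each row

# Parameters copied from the GMM printout above.
weights = [0.56339635, 0.27806376, 0.15853989]
means = [5.0550979, 0.01368472, 9.96071725]
stds = [1.03395708, 1.01027696, 1.86489054]
xs = np.linspace(-5, 15, 1000)
dens = gmm_density(xs, weights, means, stds)
print(dens.shape)  # (1000,)
```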

Below, we evaluate this density at 1,000 equally spaced points \(x \in [-5, 15]\).

from scipy import stats

x_min = -5
x_max = 15
total = 1000
probs = []
for v in np.linspace(x_min, x_max, num=total):
    # density of v under each Gaussian component
    p = np.array([stats.norm.pdf(v, gaussians[label][0], gaussians[label][1]) for label in range(gmm.n_components)])
    # weight the component densities and sum them: p'w
    probs.append(p.dot(gmm.weights_))
probs = np.array(probs)

We then plot these densities over \(X\). The curve looks just like the distribution plots of the simulated and empirical data.

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(np.linspace(x_min, x_max, num=total), probs)
ax.set_title('Probability of X')
_ = ax.set_xlabel('X')

Lastly, we can check that the curve is a probability density by numerically integrating it over \([-5, 15]\). The area should be close to 1.
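Both integration rules need the spacing between samples. For points from np.linspace(a, b, n), that spacing is (b - a)/(n - 1), because n points span n - 1 intervals. The sketch below writes out the trapezoid rule for equally spaced samples to make that explicit. trapezoid_uniform is a hypothetical helper, and scipy.integrate.trapezoid with the same dx should give the same result.

```python
import numpy as np

def trapezoid_uniform(y, a, b):
    """Trapezoid rule for samples y taken at np.linspace(a, b, len(y)) (hypothetical helper)."""
    y = np.asarray(y, dtype=float)
    dx = (b - a) / (len(y) - 1)  # n points span n - 1 intervals
    # Interior points count fully, the two endpoints count half.
    return dx * (y.sum() - 0.5 * (y[0] + y[-1]))

xs = np.linspace(-8, 8, 2001)
std_normal = np.exp(-0.5 * xs ** 2) / np.sqrt(2.0 * np.pi)
print(round(trapezoid_uniform(std_normal, -8, 8), 6))  # 1.0
```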

from scipy.integrate import simpson, trapezoid

# np.linspace(x_min, x_max, num=total) has total - 1 intervals
dx = (x_max - x_min) / (total - 1)

print('trapezoid integration area {:.3f}'.format(trapezoid(probs, dx=dx)))
print('simpsons integration area {:.3f}'.format(simpson(probs, dx=dx)))
trapezoid integration area 0.999
simpsons integration area 0.999