11. Conditional Mutual Information for Gaussian Variables

The conditional mutual information for two continuous variables \(X\) and \(Y\), given a third \(Z\), is defined as follows.

\(I(X;Y|Z) = \int_Z \int_Y \int_X p(x, y, z) \log \frac{p(z)\, p(x, y, z)}{p(x, z)\, p(y, z)} \, dx \, dy \, dz\)

Computing the conditional mutual information exactly is prohibitive: the number of possible values of \(X\), \(Y\) and \(Z\) could each be very large, and the number of their combinations is larger still. Here, we will approximate it instead. First, we will assume that \(X\), \(Y\) and \(Z\) are Gaussian distributed. Second, rather than exhaustively enumerating all combinations of \(x \in X\), \(y \in Y\) and \(z \in Z\), we will evaluate the densities only at an equal number of equally spaced points spanning the observed range of each variable.
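Concretely, the triple integral is approximated by a sum over the grid points, where \(\Delta x\), \(\Delta y\) and \(\Delta z\) denote the grid spacings:

\(I(X;Y|Z) \approx \sum_{i} \sum_{j} \sum_{k} p(x_i, y_j, z_k) \log \frac{p(z_k)\, p(x_i, y_j, z_k)}{p(x_i, z_k)\, p(y_j, z_k)} \, \Delta x \, \Delta y \, \Delta z\)

Note that the code below accumulates only the summand terms and omits the constant volume factor \(\Delta x \, \Delta y \, \Delta z\), so the resulting values are unnormalized; they are used here to compare structures rather than as exact CMI estimates.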

11.1. Simulation

Let’s simulate data from 3 graphical models.

  • \(M_S = X_1 \rightarrow X_3 \rightarrow X_2\)

  • \(M_D = X_1 \leftarrow X_3 \rightarrow X_2\)

  • \(M_C = X_1 \rightarrow X_3 \leftarrow X_2\)

\(M_S\) is called a serial model, \(M_D\) is called a diverging model and \(M_C\) is called a converging model.
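For reference, each model below is simulated with linear Gaussian structural equations (the second argument to np.random.normal is a standard deviation). For example, the serial model \(M_S\) is generated as

\(X_1 \sim \mathcal{N}(1, 1^2), \quad X_3 \,|\, X_1 \sim \mathcal{N}(1 + 3.5 X_1, 1^2), \quad X_2 \,|\, X_3 \sim \mathcal{N}(1 - 2.8 X_3, 3^2)\)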

[1]:
import numpy as np
import pandas as pd
from scipy.stats import multivariate_normal
import itertools

np.random.seed(37)

class Data(object):
    def __init__(self, data, means, cov, points=50):
        self.data = data
        self.means = means
        self.cov = cov
        self.df = pd.DataFrame(data, columns=['x1', 'x2', 'x3'])
        # Gaussian densities fitted to the data: the full joint plus the
        # marginals needed by the CMI formula, p(x, z), p(y, z) and p(z).
        self.p_xyz = multivariate_normal(means, cov)
        self.p_xz = multivariate_normal(means[[0, 2]], cov[[0, 2]][:, [0, 2]])
        self.p_yz = multivariate_normal(means[[1, 2]], cov[[1, 2]][:, [1, 2]])
        self.p_z = multivariate_normal(means[2], cov[2, 2])
        # Equally spaced evaluation points spanning the observed range of each variable.
        self.x_vals = np.linspace(self.df.x1.min(), self.df.x1.max(), num=points, endpoint=True)
        self.y_vals = np.linspace(self.df.x2.min(), self.df.x2.max(), num=points, endpoint=True)
        self.z_vals = np.linspace(self.df.x3.min(), self.df.x3.max(), num=points, endpoint=True)

    def get_cmi(self):
        x_vals = self.x_vals
        y_vals = self.y_vals
        z_vals = self.z_vals
        # All grid combinations of (x, y, z).
        prod = itertools.product(*[x_vals, y_vals, z_vals])

        p_z = self.p_z
        p_xz = self.p_xz
        p_yz = self.p_yz
        p_xyz = self.p_xyz
        # Evaluate the four densities in the integrand at each grid point.
        quads = ((p_xyz.pdf([x, y, z]), p_z.pdf(z), p_xz.pdf([x, z]), p_yz.pdf([y, z])) for x, y, z in prod)

        # Accumulate p(x,y,z) * log(p(z) p(x,y,z) / (p(x,z) p(y,z))) over the grid.
        cmi = sum((xyz * (np.log(z) + np.log(xyz) - np.log(xz) - np.log(yz)) for xyz, z, xz, yz in quads))
        return cmi


def get_serial(N=1000):
    # Serial model: x1 -> x3 -> x2.
    x1 = np.random.normal(1, 1, N)
    x3 = np.random.normal(1 + 3.5 * x1, 1, N)
    x2 = np.random.normal(1 - 2.8 * x3, 3, N)

    data = np.vstack([x1, x2, x3]).T
    means = data.mean(axis=0)
    cov = np.cov(data.T)

    return Data(data, means, cov)

def get_diverging(N=1000):
    # Diverging model: x1 <- x3 -> x2.
    x3 = np.random.normal(1, 1, N)
    x1 = np.random.normal(1 + 2.8 * x3, 1, N)
    x2 = np.random.normal(1 - 2.8 * x3, 3, N)

    data = np.vstack([x1, x2, x3]).T
    means = data.mean(axis=0)
    cov = np.cov(data.T)

    return Data(data, means, cov)

def get_converging(N=1000):
    # Converging model: x1 -> x3 <- x2.
    x1 = np.random.normal(2.8, 1, N)
    x2 = np.random.normal(8.8, 3, N)
    x3 = np.random.normal(1 + 0.8 * x1 + 0.9 * x2, 1, N)

    data = np.vstack([x1, x2, x3]).T
    means = data.mean(axis=0)
    cov = np.cov(data.T)

    return Data(data, means, cov)

m_s = get_serial()
m_d = get_diverging()
m_c = get_converging()

11.2. Estimate Conditional Mutual Information

As you can see, when testing for conditional mutual information, \(I(X_1; X_2 | X_3)\) is small for both the serial and diverging data, suggesting that \(X_1\) and \(X_2\) are nearly conditionally independent given \(X_3\). However, \(I(X_1; X_2 | X_3)\) for the converging structure is much larger, suggesting conditional dependence. This matches what d-separation predicts: in the serial and diverging models \(X_3\) blocks the path between \(X_1\) and \(X_2\), while in the converging model conditioning on the collider \(X_3\) induces dependence between its parents.
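Because the fitted distributions are jointly Gaussian, \(I(X_1; X_2 | X_3)\) also has an exact closed form in terms of sub-determinants of the covariance matrix, \(I(X;Y|Z) = \frac{1}{2} \log \frac{|\Sigma_{XZ}|\, |\Sigma_{YZ}|}{|\Sigma_{Z}|\, |\Sigma_{XYZ}|}\) (in nats). The sketch below is a helper introduced here only as a cross-check on the rankings produced by the grid sums; gaussian_cmi is not part of the code above.

def gaussian_cmi(cov):
    # Exact I(x1; x2 | x3) for jointly Gaussian variables, in nats, derived from
    # I = h(X,Z) + h(Y,Z) - h(Z) - h(X,Y,Z) with h = 0.5 * log((2*pi*e)^k |S|).
    s_xz = np.linalg.det(cov[np.ix_([0, 2], [0, 2])])
    s_yz = np.linalg.det(cov[np.ix_([1, 2], [1, 2])])
    s_z = cov[2, 2]
    s_xyz = np.linalg.det(cov)
    return 0.5 * np.log(s_xz * s_yz / (s_z * s_xyz))

# e.g. gaussian_cmi(m_s.cov), gaussian_cmi(m_d.cov), gaussian_cmi(m_c.cov)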

[2]:
%%time
m_s.get_cmi()
CPU times: user 15.8 s, sys: 278 ms, total: 16.1 s
Wall time: 14.8 s
[2]:
0.012372411431840816
[3]:
%%time
m_d.get_cmi()
CPU times: user 14.7 s, sys: 54.4 ms, total: 14.7 s
Wall time: 14.7 s
[3]:
9.612131185101602e-05
[4]:
%%time
m_c.get_cmi()
CPU times: user 14.7 s, sys: 84.7 ms, total: 14.7 s
Wall time: 14.7 s
[4]:
11.209703669891077
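Each call above takes roughly 15 seconds because the \(50 \times 50 \times 50\) grid is traversed point by point in Python. The same sum can be computed with a handful of vectorized NumPy/SciPy calls; the sketch below assumes the Data objects defined earlier and returns the same values up to floating-point rounding.

def get_cmi_vectorized(m):
    # Materialize the full grid of (x, y, z) points as an (n^3, 3) array.
    X, Y, Z = np.meshgrid(m.x_vals, m.y_vals, m.z_vals, indexing='ij')
    pts = np.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=1)

    # Evaluate each fitted density once over all grid points.
    p_xyz = m.p_xyz.pdf(pts)
    p_xz = m.p_xz.pdf(pts[:, [0, 2]])
    p_yz = m.p_yz.pdf(pts[:, [1, 2]])
    p_z = m.p_z.pdf(pts[:, 2])

    # Same accumulation as Data.get_cmi, without the Python-level loop.
    return np.sum(p_xyz * (np.log(p_z) + np.log(p_xyz) - np.log(p_xz) - np.log(p_yz)))

# e.g. get_cmi_vectorized(m_s) should agree with m_s.get_cmi()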