6. Conditional Multivariate Normal Distribution

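In this notebook we will learn about the conditional multivariate normal (MVN) distribution. In particular, we want to estimate the expected value (the mean) and the covariance of some subset of variables given that another subset has been observed (conditioned on). Though the notation is somewhat dense, it is not terribly difficult to produce a conditional MVN from the joint MVN distribution. Partition the variables into \(X_1\) (unobserved) and \(X_2\) (observed, \(X_2 = a\)), with means \(\mu_1, \mu_2\) and covariance blocks \(\Sigma_{11}, \Sigma_{12}, \Sigma_{21}, \Sigma_{22}\). Then \(X_1 \mid X_2 = a\) is normal with mean \(\mu_1 + \Sigma_{12}\Sigma_{22}^{-1}(a - \mu_2)\) and covariance \(\Sigma_{11} - \Sigma_{12}\Sigma_{22}^{-1}\Sigma_{21}\).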

6.1. Case 1, pair

  • \(X_0 \rightarrow X_1\)

[1]:
import numpy as np
from numpy.random import normal

# Seed numpy's legacy global RNG, which normal() and np.random.multivariate_normal draw from.
np.random.seed(seed=37)

def print_vector(title, v):
    """Print ``title`` on its own line, then ``v`` as ``[a, b, ...]`` with 5 decimals."""
    print(title)
    cells = [format(x, '.5f') for x in v]
    print('[' + ', '.join(cells) + ']')

def print_matrix(title, m):
    """Print ``title``, then each row of ``m`` on its own line as ``[a, b, ...]`` with 5 decimals."""
    print(title)
    lines = []
    for row in m:
        cells = ', '.join(format(x, '.5f') for x in row)
        lines.append('[' + cells + ']')
    print('\n'.join(lines))

N = 10000
# Pair X0 -> X1: X0 ~ N(0, 1), X1 | X0 ~ N(1 + 2 * X0, 1).
x0 = normal(loc=0, scale=1, size=N)
x1 = normal(loc=1 + 2 * x0, scale=1, size=N)

# One column per variable; empirical mean vector and covariance matrix.
X = np.column_stack([x0, x1])
M = X.mean(axis=0)
S = np.cov(X, rowvar=False)

print_vector('mean', M)
print_matrix('cov', S)
mean
[0.00172, 0.99428]
cov
[0.99070, 1.98924]
[1.98924, 5.00419]
[2]:
# Conditional mean E[X0 | X1 = 0.5] = mu_0 + (S_01 / S_11) * (0.5 - mu_1)
M[0] + S[0,1] / S[1,1] * (0.5 - M[1])
[2]:
-0.194758846999118
[3]:
# Conditional mean E[X1 | X0 = 0.5] = mu_1 + (S_10 / S_00) * (0.5 - mu_0)
M[1] + S[1,0] / S[0,0] * (0.5 - M[0])
[3]:
1.9947740092524469
[4]:
# Conditional variance Var[X0 | X1] = S_00 - S_01 * S_11^-1 * S_10
# (it does not depend on the observed value of X1).
S[0,0] - S[0,1] / S[1,1] * S[1,0]
[4]:
0.1999450140696054
[5]:
# Conditional variance Var[X1 | X0] = S_11 - S_10 * S_00^-1 * S_01
# (Schur complement; the last factor is S_01, not S_10, to match the formula).
S[1,1] - S[1,0] / S[0,0] * S[0,1]
[5]:
1.0099559400241556

6.2. Case 2, serial

  • \(X_0 \rightarrow X_1 \rightarrow X_2\)

[6]:
from collections import namedtuple
from numpy.linalg import inv
import warnings

# Warnings are deliberately NOT silenced: np.random.multivariate_normal warns when
# a covariance matrix is not positive-semidefinite, and hiding that warning hides
# a broken conditional covariance.
# Partitioned covariance blocks used for conditioning the i1 variables on i2:
#   C11 = S[i1, i1], C12 = S[i1, i2], C21 = S[i2, i1], C22 = S[i2, i2],
#   C22I = inverse of C22.
COV = namedtuple('COV', 'C11 C12 C21 C22 C22I')

def to_row_indices(indices):
    """Wrap each index in its own list, giving a (k, 1)-shaped row selector for fancy indexing."""
    rows = []
    for idx in indices:
        rows.append([idx])
    return rows

def to_col_indices(indices):
    """Column selector for fancy indexing: the flat index list itself, returned as-is.

    Combined with the (k, 1) selector from ``to_row_indices``, ``S[rows, cols]``
    broadcasts to the rows-by-cols sub-matrix.
    """
    cols = indices
    return cols

def get_covariances(i1, i2, S):
    """
    Partition covariance matrix ``S`` into blocks for conditioning ``i1`` on ``i2``.

    :param i1: Indices of the unconditioned variables.
    :param i2: Indices of the conditioned variables.
    :param S: Full covariance matrix.
    :return: ``COV`` holding S[i1,i1], S[i1,i2], S[i2,i1], S[i2,i2] and the inverse of S[i2,i2].
    """
    def block(rows, cols):
        # (k, 1) row selector broadcast against the flat column list -> k x m sub-matrix.
        return S[to_row_indices(rows), to_col_indices(cols)]

    c11 = block(i1, i1)
    c12 = block(i1, i2)
    c21 = block(i2, i1)
    c22 = block(i2, i2)
    return COV(c11, c12, c21, c22, inv(c22))

def compute_means(a, M, C, i1, i2):
    """
    Conditional mean of the ``i1`` variables given that the ``i2`` variables equal ``a``.

    Implements mu_1 + C12 C22^-1 (a - mu_2).

    :param a: Observed values of the conditioned variables, ordered like ``i2``
        (array-like; a list or numpy array both work).
    :param M: Mean vector of the joint distribution.
    :param C: Covariance blocks; only ``C.C12`` and ``C.C22I`` are read.
    :param i1: Indices of the unconditioned variables.
    :param i2: Indices of the conditioned variables.
    :return: Array of conditional means, ordered like ``i1``.
    """
    # Use the caller's observed values (previously they were overwritten with a
    # hard-coded 2.0, so every conditioning silently used X = 2.0).
    a = np.asarray(a, dtype=float)
    M = np.asarray(M)
    return M[i1] + C.C12.dot(C.C22I).dot(a - M[i2])

def compute_covs(C):
    """
    Conditional covariance of block 1 given block 2 (Schur complement).

    :param C: Covariance blocks; reads ``C11``, ``C12``, ``C21`` and ``C22I``.
    :return: C11 - C12 C22^-1 C21.
    """
    gain = C.C12.dot(C.C22I)
    return C.C11 - gain.dot(C.C21)

def update_mean(m, a, M, i1, i2):
    """
    Return a copy of ``M`` with conditional means and observed values written in.

    :param m: Conditional means of the unconditioned variables, ordered like ``i1``.
    :param a: Observed values of the conditioned variables, ordered like ``i2``.
    :param M: Original mean vector (not modified).
    :param i1: Indices of the unconditioned variables.
    :param i2: Indices of the conditioned variables.
    :return: New mean vector; ``i2`` entries are written after ``i1`` entries.
    """
    updated = np.copy(M)
    assignments = list(zip(i1, m)) + list(zip(i2, a))
    for idx, value in assignments:
        updated[idx] = value
    return updated

def update_cov(c, S, i1, i2, eps=0.01):
    """
    Build the full covariance matrix after conditioning on the ``i2`` variables.

    Conditioned variables are fixed at their observed values, so they cannot
    covary with anything. Every covariance involving an ``i2`` variable is
    therefore zeroed. Each conditioned variable keeps only a small positive
    variance ``eps`` on the diagonal, so the matrix is not exactly singular.
    The ``i1`` block is replaced by the conditional covariance ``c``.

    :param c: Conditional covariance of the ``i1`` variables, shape (len(i1), len(i1)).
    :param S: Original covariance matrix (not modified).
    :param i1: Indices of the unconditioned variables.
    :param i2: Indices of the conditioned variables.
    :param eps: Variance assigned to each conditioned variable (default 0.01).
    :return: New covariance matrix with the same shape as ``S``.
    """
    m = np.copy(S)
    i1 = list(i1)
    i2 = list(i2)
    if i1:
        m[np.ix_(i1, i1)] = c
    if i2:
        # Keeping S's cross-covariances next to a tiny variance would make the
        # matrix non-positive-semidefinite, and sampling from it gives garbage.
        m[i2, :] = 0.0
        m[:, i2] = 0.0
        m[i2, i2] = eps
    return m

def update_mean_cov(v, iv, M, S):
    """
    Condition the MVN N(M, S) on the variables ``iv`` taking the values ``v``.

    :param v: Observed values, one per index in ``iv``.
    :param iv: Indices of the variables being conditioned on (list, tuple, range or array).
    :param M: Mean vector of the joint distribution.
    :param S: Covariance matrix of the joint distribution.
    :return: ``(M_u, S_u)``: full-length mean vector and full-size covariance
        matrix built by ``update_mean`` / ``update_cov`` from the conditional
        mean and covariance. When ``v`` or ``iv`` is None or empty, copies of
        ``M`` and ``S`` are returned unchanged.
    :raises ValueError: If ``v`` and ``iv`` have different lengths.
    """
    if v is None or iv is None or len(v) == 0 or len(iv) == 0:
        return np.copy(M), np.copy(S)
    if len(v) != len(iv):
        # A mismatch would otherwise be silently broadcast in the mean formula
        # and truncated by zip() when writing the observed values back.
        raise ValueError(
            f'got {len(v)} observed values for {len(iv)} conditioned indices')

    i2 = list(iv)
    conditioned = set(i2)
    i1 = [i for i in range(S.shape[0]) if i not in conditioned]

    C = get_covariances(i1, i2, S)
    m = compute_means(v, M, C, i1, i2)
    c = compute_covs(C)
    M_u = update_mean(m, v, M, i1, i2)
    S_u = update_cov(c, S, i1, i2)
    return M_u, S_u
[7]:
N = 10000
# Serial chain X0 -> X1 -> X2: each child is linear in its parent plus unit noise.
x0 = normal(loc=0, scale=1, size=N)
x1 = normal(loc=1 + 2 * x0, scale=1, size=N)
x2 = normal(loc=1 + 2 * x1, scale=1, size=N)

# One column per variable; empirical mean, covariance and correlation.
X = np.column_stack([x0, x1, x2])
M = X.mean(axis=0)
S = np.cov(X, rowvar=False)

print_vector('mean', M)
print('>')
print_matrix('cov', S)
print('>')
print_matrix('corr', np.corrcoef(X, rowvar=False))
mean
[0.00499, 0.99888, 3.01284]
>
cov
[0.98453, 1.98373, 3.95254]
[1.98373, 5.01127, 9.99960]
[3.95254, 9.99960, 20.97023]
>
corr
[1.00000, 0.89309, 0.86988]
[0.89309, 1.00000, 0.97545]
[0.86988, 0.97545, 1.00000]
[8]:
# Observe the middle variable, X1 = 2.0. In a serial chain X0 and X2 are
# conditionally independent given X1, so their conditional covariance should be ~0.
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print_vector('mean', M_u)
print('>')
print_matrix('cov', S_u)
print('>')
# Empirical correlations from N * 10 draws of the conditional distribution.
print_matrix(
    'corr',
    np.corrcoef(np.random.multivariate_normal(M_u, S_u, size=N * 10), rowvar=False),
)
mean
[0.40128, 2.00000, 5.01049]
>
cov
[0.19927, 1.98373, -0.00584]
[1.98373, 0.01000, 9.99960]
[-0.00584, 9.99960, 1.01681]
>
corr
[1.00000, -0.02114, 0.77507]
[-0.02114, 1.00000, 0.04972]
[0.77507, 0.04972, 1.00000]

6.3. Case 3, diverging

  • \(X_0 \leftarrow X_1 \rightarrow X_2\)

[9]:
N = 10000

# Diverging X0 <- X1 -> X2: X1 is the common cause, so it is drawn first.
x1 = normal(loc=0, scale=1, size=N)
x0 = normal(loc=1 + 4.0 * x1, scale=1, size=N)
x2 = normal(loc=1 + 2.0 * x1, scale=1, size=N)

# One column per variable; empirical mean, covariance and correlation.
X = np.column_stack([x0, x1, x2])
M = X.mean(axis=0)
S = np.cov(X, rowvar=False)

print_vector('mean', M)
print('>')
print_matrix('cov', S)
print('>')
print_matrix('corr', np.corrcoef(X, rowvar=False))
mean
[0.98517, -0.00131, 1.00396]
>
cov
[16.98775, 3.99342, 7.96496]
[3.99342, 0.99839, 1.98856]
[7.96496, 1.98856, 4.93653]
>
corr
[1.00000, 0.96968, 0.86977]
[0.96968, 1.00000, 0.89573]
[0.86977, 0.89573, 1.00000]
[10]:
# Observe the common cause, X1 = 2.0. Given X1, X0 and X2 are conditionally
# independent, so their conditional covariance should be ~0.
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print_vector('mean', M_u)
print('>')
print_matrix('cov', S_u)
print('>')
# Empirical correlations from N * 10 draws of the conditional distribution.
print_matrix(
    'corr',
    np.corrcoef(np.random.multivariate_normal(M_u, S_u, size=N * 10), rowvar=False),
)
mean
[8.99009, 2.00000, 4.99009]
>
cov
[1.01467, 3.99342, 0.01100]
[3.99342, 0.01000, 1.98856]
[0.01100, 1.98856, 0.97577]
>
corr
[1.00000, 0.11161, 0.56147]
[0.11161, 1.00000, 0.08156]
[0.56147, 0.08156, 1.00000]

6.4. Case 4, converging

  • \(X_0 \rightarrow X_1 \leftarrow X_2\)

[11]:
N = 10000

# Converging X0 -> X1 <- X2: independent parents, X1 is their common effect.
x0 = normal(loc=0, scale=1, size=N)
x2 = normal(loc=0, scale=1, size=N)
x1 = normal(loc=1 + 2 * x0 + 3 * x2, scale=1, size=N)

# One column per variable; empirical mean, covariance and correlation.
X = np.column_stack([x0, x1, x2])
M = X.mean(axis=0)
S = np.cov(X, rowvar=False)

print_vector('mean', M)
print('>')
print_matrix('cov', S)
print('>')
print_matrix('corr', np.corrcoef(X, rowvar=False))
mean
[-0.00565, 0.97046, -0.01113]
>
cov
[0.97729, 1.99763, 0.01513]
[1.99763, 14.06103, 3.01565]
[0.01513, 3.01565, 0.99463]
>
corr
[1.00000, 0.53888, 0.01535]
[0.53888, 1.00000, 0.80638]
[0.01535, 0.80638, 1.00000]
[12]:
# Observe the common effect, X1 = 2.0. Conditioning on a collider makes the
# previously independent parents X0 and X2 dependent.
M_u, S_u = update_mean_cov(np.array([2.0]), [1], M, S)

print_vector('mean', M_u)
print('>')
print_matrix('cov', S_u)
print('>')
# Empirical correlations from N * 10 draws of the conditional distribution.
print_matrix(
    'corr',
    np.corrcoef(np.random.multivariate_normal(M_u, S_u, size=N * 10), rowvar=False),
)
mean
[0.14062, 2.00000, 0.20968]
>
cov
[0.69349, 1.99763, -0.41330]
[1.99763, 0.01000, 3.01565]
[-0.41330, 3.01565, 0.34787]
>
corr
[1.00000, 0.00407, 0.55062]
[0.00407, 1.00000, 0.00786]
[0.55062, 0.00786, 1.00000]