Testing with simulations

A notebook for testing mushi’s ability to invert data simulated under the forward model

[1]:
import mushi

import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from scipy.special import expit
import stdpopsim
/opt/hostedtoolcache/Python/3.9.10/x64/lib/python3.9/site-packages/traitlets/traitlets.py:3044: FutureWarning: --rc={'figure.dpi': 96} for dict-traits is deprecated in traitlets 5.0. You can pass --rc <key=value> ... multiple times to add items to a dict.
  warn(

Time grid

[2]:
# 200 log-spaced epoch boundaries running from 10**0 up to 10**5
change_points = np.logspace(0, np.log10(100000), num=200)
# Full time grid: the same boundaries with time 0 placed in front
t = np.insert(change_points, 0, 0)

Define true demographic history

[3]:
# Fetch the "Zigzag_1S14" demographic model for HomSap from stdpopsim
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("Zigzag_1S14")
ddb = model.get_demography_debugger()

# First element of the trajectory result, evaluated on the full grid t
# with lineages={0: 2}
coalescence_rate = ddb.coalescence_rate_trajectory(
    steps=t,
    lineages={0: 2},
    double_step_validation=False,
)[0]

# The true history is built from the reciprocal of that rate
eta_true = mushi.eta(change_points, 1 / coalescence_rate)

plt.figure(figsize=(3.5, 3.5))
eta_true.plot(c='k');
../_images/notebooks_simulation_5_0.svg

Mutation rate history \(\mu(t)\)

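A 96-dimensional history built from three latent signatures: constant, pulse, and ramp. Two focal mutation types carry the pulse and ramp components, and the 94 non-focal types are represented by a single aggregated constant column.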

[4]:
# Latent time profiles evaluated on the full time grid t
flat = np.ones_like(t)
pulse = expit(.1 * (t - 100)) - expit(.01 * (t - 2000))
ramp = expit(-.01 * (t - 100))

# Baseline scale applied to every column of Z. It is deliberately NOT called
# `mu0`: that name is bound later in the notebook to the total mutation rate
# that is passed to inference, and sharing the name meant that re-running
# this cell silently reset that value to 1.
mu_scale = 1
# NOTE(review): nothing in this cell draws random numbers; the seed is kept
# in case later steps rely on numpy's global RNG state -- confirm.
np.random.seed(0)

# One column per labelled mutation type plus one unlabelled column
cols = 3
Z = np.zeros((len(t), cols))
Z[:, 0] = mu_scale * (1 * flat + .5 * pulse)   # constant + pulse
Z[:, 1] = mu_scale * (.5 * flat + .4 * ramp)   # constant + ramp
# presumably the remaining 94 of 96 types lumped together -- TODO confirm
Z[:, 2] = 94 * mu_scale * flat

mutation_types = ['TCC>TTC', 'GAA>GGA', None]

mu_true = mushi.mu(change_points, Z, mutation_types)

# Overlay the two labelled types on one figure
plt.figure(figsize=(4, 4))
mu_true.plot(('TCC>TTC',), alpha=0.75, lw=3, clr=False)
mu_true.plot(('GAA>GGA',), alpha=0.75, lw=3, clr=False);
../_images/notebooks_simulation_7_0.svg

Estimate the total mutation rate using \(t=0\)

[5]:
# Total mutation rate: sum across all columns of the first row of Z
# (the row at the start of the time grid)
first_row = mu_true.Z[0]
mu0 = first_row.sum()
print(mu0)
95.79244612935578

Simulate a \(k\)-SFS

  • We’ll sample 200 haplotypes

  • Note that the total mutation rate in this simulation varies slightly over time, because the pulse and ramp components of \(\mu(t)\) are time-dependent

[6]:
# Number of sampled haplotypes
n = 200
ksfs = mushi.kSFS(n=n)
# Simulate under the true eta and mu histories, with r=0.02 and a fixed seed
ksfs.simulate(eta_true, mu_true, r=0.02, seed=1)

# Total spectrum as bare markers on log-log axes
plt.figure(figsize=(4, 3))
ksfs.plot_total(kwargs={'ls': '', 'marker': '.'})
plt.xscale('log')
plt.yscale('log')

# Spectra of the two labelled mutation types on a second figure
plt.figure(figsize=(4, 3))
for mutation_type in ('TCC>TTC', 'GAA>GGA'):
    ksfs.plot((mutation_type,), clr=True,
              kwargs={'alpha': 0.75, 'ls': '', 'marker': 'o'})
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
../_images/notebooks_simulation_11_1.svg
../_images/notebooks_simulation_11_2.svg

Number of segregating sites

[7]:
# Sum over every entry of the simulated k-SFS array `ksfs.X`
ksfs.X.sum()
[7]:
DeviceArray(9507552, dtype=int64)

TMRCA CDF

[8]:
# TMRCA CDF under the true history, drawn against the epoch boundaries
fig, ax = plt.subplots(figsize=(3.5, 3.5))
ax.plot(change_points, ksfs.tmrca_cdf(eta_true))
ax.set_xlabel('$t$')
ax.set_ylabel('TMRCA CDF')
ax.set_ylim(0, 1)        # fixed y-range of [0, 1]
ax.set_xscale('log')     # the time grid is log-spaced
fig.tight_layout()
../_images/notebooks_simulation_15_0.svg

Inference

convergence parameters and time grid

[9]:
# Keyword arguments unpacked into both inference calls below
convergence = {
    'tol': 0,
    'max_iter': 100,
    'trend_kwargs': {'max_iter': 20},
}
# Passed as `pts` to infer_eta
pts = 100

Infer \(\eta(t)\)

[10]:
# Positional penalty pairs handed to infer_eta
trend_penalties = ((0, 1e1), (1, 1e0))

# Shared by the inference call and the spectrum plot below
folded = False

# Reset any previous fit and the stored r before inferring
ksfs.clear_eta()
ksfs.clear_mu()
ksfs.r = None

ksfs.infer_eta(
    mu0,
    *trend_penalties,
    ridge_penalty=1e-4,
    pts=pts,
    folded=folded,
    verbose=True,
    **convergence,
)

# Only report r if it was set during inference
if ksfs.r is not None:
    print(f'inferred ancestral misidentification rate: {ksfs.r:.3f}')

plt.figure(figsize=(8, 4))

# Left panel: total spectrum on log-log axes
plt.subplot(1, 2, 1)
ksfs.plot_total(
    kwargs={'ls': '', 'marker': 'o', 'ms': 5, 'c': 'k', 'alpha': 0.75},
    line_kwargs={'c': 'C0', 'alpha': 0.75, 'lw': 3},
    fill_kwargs={'color': 'C0', 'alpha': 0.1},
    folded=folded,
)
plt.xscale('log')
plt.yscale('log')

# Right panel: true history against the inferred one
plt.subplot(1, 2, 2)
eta_true.plot(c='k', lw=2, label='true')
ksfs.eta.plot(lw=3, alpha=0.75, label='inferred')
plt.legend()
plt.tight_layout()
plt.show()
initial objective -1.037540e+08
iteration 100, objective -1.039e+08, relative change 4.006e-08
maximum iteration 100 reached with relative change in objective function 4e-08
inferred ancestral misidentification rate: 0.020
../_images/notebooks_simulation_19_1.svg

Infer \(\boldsymbol\mu(t)\)

[11]:
# Drop any earlier mutation-rate fit before refitting
ksfs.clear_mu()

# Positional penalty pairs handed to infer_mush
trend_penalties = ((0, 2e2), (3, 1e-1))

ksfs.infer_mush(
    *trend_penalties,
    ridge_penalty=1e-4,
    verbose=True,
    **convergence,
)

# (mutation type, colour) pairs reused by both panels so the colours match
focal_types = (('TCC>TTC', 'C0'), ('GAA>GGA', 'C1'))

plt.figure(figsize=(8, 4))

# Left panel: spectra of the two labelled types
plt.subplot(1, 2, 1)
for mutation_type, colour in focal_types:
    ksfs.plot((mutation_type,), clr=True,
              kwargs=dict(alpha=0.75, ls='', marker='.', ms=10, mfc='none', c=colour),
              line_kwargs=dict(alpha=0.75, lw=2, c=colour))

# Right panel: true history (solid) followed by the inferred one (dashed)
plt.subplot(1, 2, 2)
for mutation_type, colour in focal_types:
    mu_true.plot((mutation_type,), alpha=0.75, lw=2, c=colour)
    ksfs.mu.plot((mutation_type,), alpha=0.75, lw=3, ls='--', c=colour)
plt.tight_layout()
plt.show()
initial objective -1.028993e+08
iteration 100, objective -1.029e+08, relative change 4.899e-10
maximum iteration 100 reached with relative change in objective function 4.9e-10
../_images/notebooks_simulation_21_1.svg
[ ]: