Testing with simulations
A notebook for testing `mushi`’s ability to invert data simulated under the forward model.
[1]:
import mushi
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from scipy.special import expit
import stdpopsim
/opt/hostedtoolcache/Python/3.9.10/x64/lib/python3.9/site-packages/traitlets/traitlets.py:3044: FutureWarning: --rc={'figure.dpi': 96} for dict-traits is deprecated in traitlets 5.0. You can pass --rc <key=value> ... multiple times to add items to a dict.
warn(
Time grid
[2]:
# Epoch boundaries: 200 log-spaced points running from 10**0 up to 100000.
change_points = np.logspace(start=0, stop=np.log10(100000), num=200)
# Full time grid: the same boundaries with t = 0 prepended.
t = np.concatenate(([0], change_points))
Define true demographic history
[3]:
# Look up the "Zigzag_1S14" demographic model for "HomSap" in stdpopsim and
# obtain its demography debugger.
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("Zigzag_1S14")
ddb = model.get_demography_debugger()

# Evaluate the coalescence-rate trajectory on the time grid `t`.
# NOTE(review): lineages={0: 2} presumably means two lineages sampled from
# population 0 -- confirm against the msprime documentation.
trajectory = ddb.coalescence_rate_trajectory(
    steps=t,
    lineages={0: 2},
    double_step_validation=False,
)
# Only the first element of the returned value is used.
coalescence_rate = trajectory[0]

# The true history handed to mushi is the reciprocal of the coalescence rate.
eta_true = mushi.eta(change_points, 1 / coalescence_rate)

plt.figure(figsize=(3.5, 3.5))
eta_true.plot(c='k');
Mutation rate history \(\mu(t)\)
A 96-dimensional history built from a mixture of latent signatures: constant, pulse, and ramp. The 94 mutation types that are not plotted are lumped into a single constant component.
[4]:
# Building-block trajectories evaluated on the time grid `t`.
flat = np.ones_like(t)
pulse = expit(.1 * (t - 100)) - expit(.01 * (t - 2000))
ramp = expit(-.01 * (t - 100))

mu0 = 1
# NOTE(review): nothing visible in this cell draws random numbers; the seed is
# kept so the global NumPy RNG state is unchanged for anything that follows.
np.random.seed(0)

# One label per column of Z; the last column is passed without a label (None).
mutation_types = ['TCC>TTC', 'GAA>GGA', None]
cols = len(mutation_types)

# Assemble the per-type trajectories side by side: one row per time point,
# one column per entry of `mutation_types`.
Z = np.column_stack([
    mu0 * (1 * flat + .5 * pulse),
    mu0 * (.5 * flat + .4 * ramp),
    94 * mu0 * flat,
])
mu_true = mushi.mu(change_points, Z, mutation_types)

# Plot only the two labelled mutation types.
plt.figure(figsize=(4, 4))
for mutation_type in mutation_types[:2]:
    mu_true.plot((mutation_type,), alpha=0.75, lw=3, clr=False)
Estimate the total mutation rate using \(t=0\)
[5]:
# Total mutation rate: sum the first row of Z (index 0 along the time axis)
# across all mutation-type columns.
mu0 = mu_true.Z[0].sum()
print(mu0)
95.79244612935578
Simulate a \(k\)-SFS
We’ll sample 200 haplotypes.
Note that this simulation will have a slightly varying total mutation rate.
[6]:
# Simulate a k-SFS for n sampled haplotypes under the true histories.
# NOTE(review): r is presumably the ancestral misidentification rate (the
# inference cell later prints ksfs.r under that name) -- confirm in mushi docs.
n = 200
ksfs = mushi.kSFS(n=n)
ksfs.simulate(eta_true, mu_true, r=0.02, seed=1)

# Total SFS on log-log axes.
plt.figure(figsize=(4, 3))
ksfs.plot_total(kwargs={'ls': '', 'marker': '.'})
plt.xscale('log')
plt.yscale('log')

# The two labelled mutation types, drawn on one shared figure.
plt.figure(figsize=(4, 3))
for mutation_type in ('TCC>TTC', 'GAA>GGA'):
    ksfs.plot((mutation_type,), clr=True,
              kwargs={'alpha': 0.75, 'ls': '', 'marker': 'o'})
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Number of segregating sites
[7]:
# Sum of every entry of ksfs.X; left as the cell's last expression so the
# notebook displays the value.
ksfs.X.sum()
[7]:
DeviceArray(9507552, dtype=int64)
TMRCA CDF
[8]:
# TMRCA CDF under the true demographic history, one value per change point.
tmrca_cdf = ksfs.tmrca_cdf(eta_true)

plt.figure(figsize=(3.5, 3.5))
plt.plot(change_points, tmrca_cdf)
plt.xscale('log')
plt.ylim(0, 1)
plt.xlabel('$t$')
plt.ylabel('TMRCA CDF')
plt.tight_layout()
Inference
Convergence parameters and time grid
[9]:
# Keyword arguments that the inference cells splat into infer_eta / infer_mush.
convergence = {
    'tol': 0,
    'max_iter': 100,
    'trend_kwargs': {'max_iter': 20},
}
# Passed to infer_eta as `pts`.
pts = 100
Infer \(\eta(t)\)
[10]:
# NOTE(review): each pair is presumably (trend order, penalty strength) --
# confirm against the mushi documentation. They are unpacked as positional
# arguments to infer_eta below.
trend_penalties = ((0, 1e1), (1, 1e0))
folded = False

# Reset stored histories and the r attribute before fitting, so that re-running
# this cell does not start from a previous fit.
ksfs.clear_eta()
ksfs.clear_mu()
ksfs.r = None

ksfs.infer_eta(
    mu0,
    *trend_penalties,
    ridge_penalty=1e-4,
    pts=pts,
    **convergence,
    verbose=True,
    folded=folded,
)

# Report r only if it has been set to something other than None.
if ksfs.r is not None:
    print(f'inferred ancestral misidentification rate: {ksfs.r:.3f}')

plt.figure(figsize=(8, 4))

# Left panel: total SFS with the fit overlaid, on log-log axes.
plt.subplot(1, 2, 1)
ksfs.plot_total(
    kwargs={'ls': '', 'marker': 'o', 'ms': 5, 'c': 'k', 'alpha': 0.75},
    line_kwargs={'c': 'C0', 'alpha': 0.75, 'lw': 3},
    fill_kwargs={'color': 'C0', 'alpha': 0.1},
    folded=folded,
)
plt.xscale('log')
plt.yscale('log')

# Right panel: true versus inferred history.
plt.subplot(1, 2, 2)
eta_true.plot(c='k', lw=2, label='true')
ksfs.eta.plot(lw=3, alpha=0.75, label='inferred')
plt.legend()

plt.tight_layout()
plt.show()
initial objective -1.037540e+08
iteration 100, objective -1.039e+08, relative change 4.006e-08
maximum iteration 100 reached with relative change in objective function 4e-08
inferred ancestral misidentification rate: 0.020
Infer \(\boldsymbol\mu(t)\)
[11]:
# Drop any previously stored mu so the fit below starts fresh.
ksfs.clear_mu()

# NOTE(review): each pair is presumably (trend order, penalty strength) --
# confirm against the mushi documentation.
trend_penalties = ((0, 2e2), (3, 1e-1))
ksfs.infer_mush(
    *trend_penalties,
    ridge_penalty=1e-4,
    **convergence,
    verbose=True,
)

# The two labelled mutation types, each with a fixed colour used in both panels.
focal_types = (('TCC>TTC', 'C0'), ('GAA>GGA', 'C1'))

plt.figure(figsize=(8, 4))

# Left panel: per-type k-SFS with the fit overlaid.
plt.subplot(1, 2, 1)
for mutation_type, color in focal_types:
    ksfs.plot(
        (mutation_type,),
        clr=True,
        kwargs={'alpha': 0.75, 'ls': '', 'marker': '.', 'ms': 10,
                'mfc': 'none', 'c': color},
        line_kwargs={'alpha': 0.75, 'lw': 2, 'c': color},
    )

# Right panel: true (solid) versus inferred (dashed) history for each type.
plt.subplot(1, 2, 2)
for mutation_type, color in focal_types:
    mu_true.plot((mutation_type,), alpha=0.75, lw=2, c=color)
    ksfs.mu.plot((mutation_type,), alpha=0.75, lw=3, ls='--', c=color)

plt.tight_layout()
plt.show()
initial objective -1.028993e+08
iteration 100, objective -1.029e+08, relative change 4.899e-10
maximum iteration 100 reached with relative change in objective function 4.9e-10
[ ]: