Call compartment#

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import anndata
import scanpy as sc
from glob import glob
import time
import cooler
import hba_data
from itertools import cycle, islice
from sklearn.decomposition import PCA
from scipy.stats import pearsonr
from concurrent.futures import ProcessPoolExecutor, as_completed

mpl.style.use('default')
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['font.sans-serif'] = 'Helvetica'
indir = '/home/jzhou_salk_edu/sky_workdir/hba/loop_majortype/'
outdir = '/home/jzhou_salk_edu/sky_workdir/hba/compartment_majortype/'
res = 100000
chrom_size_path = '/home/jzhou_salk_edu/sky_workdir/hba/ref/hg38.main.chrom.sizes'
chrom_sizes = cooler.read_chromsizes(chrom_size_path, all_names=True)
chrom_sizes = chrom_sizes.iloc[:23]
cpg = pd.read_csv('/home/jzhou_salk_edu/sky_workdir/hba/ref/hg38.100kbin.CpG.txt', header=0, index_col=3, sep='\t')
cpg['CpG_density'] = cpg['14_user_patt_count'] / (cpg['13_seq_len'] - cpg['11_num_N'])

Fit PC model with merged raw#

Qall = []
binall = []
ct = 'merged'
# cool = cooler.Cooler(f'/gale/netapp/entex/HBA/snm3C/compartment_majortype/{ct}/{ct}.raw.mcool::resolutions/100000')
cool = cooler.Cooler(f'{indir}{ct}/{ct}/{ct}.raw.mcool::resolutions/100000')
fig, axes = plt.subplots(5, 5, figsize=(15,15))
for i,c in enumerate(chrom_sizes.index):
    Q = cool.matrix(balance=False, sparse=True).fetch(c).toarray()
    Q = Q - np.diag(np.diag(Q))
    rowsum = Q.sum(axis=0)
    thres = [np.percentile(rowsum[rowsum>0], 50), np.percentile(rowsum[rowsum>0], 99)]
    thres.append(thres[0]*2-thres[1])
    ax = axes.flatten()[i]
    sns.histplot(rowsum, bins=100, ax=ax)
    for t in thres:
        ax.plot([t, t], [0, ax.get_ylim()[1]*0.5], c='r')
    ax.set_title(c)
    binfilter = (rowsum>thres[-1])
    binall.append(binfilter)
    Q = Q[binfilter][:, binfilter]
    Qall.append(Q)
    print(c)

for i in range(chrom_sizes.shape[0], axes.flatten().shape[0]):
    axes.flatten()[i].axis('off')
    
chr1
chr2
chr3
chr4
chr5
chr6
chr7
chr8
chr9
chr10
chr11
chr12
chr13
chr14
chr15
chr16
chr17
chr18
chr19
chr20
chr21
chr22
chrX
../../_images/a60ac28cd02d4e73b34d029385a4caa19863dc36cfc0b55428b28b7192cd2013.png
fig, axes = plt.subplots(5, 5, figsize=(15,15))
for i,c in enumerate(chrom_sizes.index):
    ax = axes.flatten()[i]
    n_bins = int(chrom_sizes[c] // 100000) + 1
    tmp = np.zeros((n_bins, n_bins))
    tmp[np.ix_(binall[i], binall[i])] = Qall[i]
    ax.imshow(tmp, cmap='bwr', vmin=-np.percentile(Qall[i], 95), vmax=np.percentile(Qall[i], 95))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(c, fontsize=15)

for i in range(chrom_sizes.shape[0], axes.flatten().shape[0]):
    axes.flatten()[i].axis('off')
    
plt.tight_layout()
# plt.savefig(f'{indir}/plot/celltype_A_decay.pdf', transparent=True)
findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
findfont: Generic family 'sans-serif' not found because none of the following families were found: Helvetica
../../_images/584665393e130b94f4d62996b783b9f027ebfaeac23fca362f3afb2ebfab3e01.png
# comp = pd.read_hdf(f'{indir}compartment_majortype/comp_raw_merged.hdf', key='data')

Call = []
pcall = []
modelall = []
for k,chrom in enumerate(chrom_sizes.index[:-1]):
    Q = Qall[k].copy()
    decay = np.array([np.mean(np.diag(Q, i)) for i in range(Q.shape[0])])
    E = np.zeros(Q.shape)
    row, col = np.diag_indices(E.shape[0])
    E[row, col] = 1
    for i in range(1, E.shape[0]):
        E[row[:-i], col[i:]] = (Q[row[:-i], col[i:]] + 1e-5) / (decay[i] + 1e-5)
    E = E + E.T
    C = np.corrcoef(np.log2(E + 0.001))
    Call.append(C)
    binfilter = binall[k]
#     tmp = comp.loc[ct, (comp.columns.str.split('_').str[0]==chrom)]
#     tmp.index = [int(xx.split('_')[1]) for xx in tmp.index]
#     pcall.append(tmp[np.where(binfilter)[0]].values)
    pca = PCA(n_components=2)
    pc = pca.fit_transform(C)
    cpgtmp = cpg.loc[cpg['#1_usercol']==chrom, 'CpG_density'].values[binfilter]
#     i = 0
#     if np.abs(pearsonr(cpgtmp, pc[:,0])[0])>np.abs(pearsonr(cpgtmp, pc[:,1])[0]):
#         i = 0
#     else:
#         i = 1
    r = []
    for i in range(2):
        labels, groups = pd.qcut(pc[:,i], 50, labels=False, retbins=True)
        sad = np.array([[E[np.ix_(labels==i, labels==j)].sum() for i in range(50)] for j in range(50)])
        count = np.array([[(labels==i).sum()*(labels==j).sum() for i in range(50)] for j in range(50)])
        sad = sad / count
        r.append((sad[:10, :10].sum() + sad[-10:, -10:].sum()) / (sad[:10, -10:].sum() + sad[-10:, :10].sum()))
    if r[0]>r[1]:
        i = 0
    else:
        i = 1
    if pearsonr(cpgtmp, pc[:,i])[0]>0:
        pc = pc[:,i]
        modelall.append(pca.components_[i])
    else:
        pc = -pc[:,i]
        modelall.append(-pca.components_[i])
    pcall.append(pc)
    print(chrom, i)
chr1 0
chr2 0
chr3 0
chr4 0
chr5 0
chr6 0
chr7 0
chr8 0
chr9 0
chr10 0
chr11 0
chr12 0
chr13 0
chr14 0
chr15 0
chr16 0
chr17 0
chr18 0
chr19 0
chr20 0
chr21 0
chr22 0
fig, axes = plt.subplots(2, 22, figsize=(66,4), gridspec_kw={'height_ratios':[5,0.2]}, sharex='col')
for i,c in enumerate(chrom_sizes.index[:-1]):
    ax = axes[0,i]
    n_bins = int(chrom_sizes[c] // 100000) + 1
    tmp = np.zeros((n_bins, n_bins))
    tmp[np.ix_(binall[i], binall[i])] = Call[i]
    ax.imshow(tmp, cmap='bwr', vmin=np.percentile(Call[i], 5), vmax=np.percentile(Call[i], 95))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(c, fontsize=15)

    ax = axes[1,i]
    # ax.set_title('PC1', fontsize=10)
    sns.despine(bottom=True, left=True, ax=ax)
    tmp = np.zeros(n_bins)
    tmp[binall[i]] = pcall[i]# / np.std(pcall[i])
    x, y = np.arange(n_bins), tmp
    # x, y = np.arange(pcall[i].shape[0]), pcall[i]
    ax.fill_between(x, y, 0, where=y >= 0, facecolor='C3', interpolate=True)
    ax.fill_between(x, y, 0, where=y <= 0, facecolor='C0', interpolate=True)
    ax.set_yticks([])
    ax.set_ylim([np.percentile(y, 1), np.percentile(y, 99)])

plt.tight_layout()
# plt.savefig(f'{indir}/plot/celltype_compraw.pdf', transparent=True)
../../_images/4c6e99aefa68aa672145d3c78ed91b4bb44a2a587c0a6d7e4b79d0f6d92aa290.png
np.save(f'{outdir}binfilter_raw.npy', binall)
np.save(f'{outdir}pcloading_raw.npy', modelall)

Saddle plot of merged raw pc transformed compartment#

leg = ['L23_IT', 'L4_IT', 'L56_NP', 'L5_ET', 'L5_IT', 'L6_CT', 'L6_IT', 'L6_IT_Car3', 'L6b', 
       'Lamp5', 'Lamp5_LHX6', 'Sncg', 'Vip', 'Pvalb', 'Pvalb_ChC', 'Sst', 
       'MSN_D1', 'MSN_D2', 'SubCtx', 'Amy', 'CHD7', 'Foxp2', 
       'ASC', 'ODC', 'OPC', 'MGC', 'EC', 'PC', 'VLMC', 'merged']
legname = ['L2/3-IT', 'L4-IT', 'L5/6-NP', 'L5-ET', 'L5-IT', 'L6-CT', 'L6-IT', 'L6-IT-Car3', 'L6b', 
       'Lamp5', 'Lamp5-Lhx6', 'Sncg', 'Vip', 'Pvalb', 'Pvalb-ChC', 'Sst', 
       'MSN-D1', 'MSN-D2', 'SubCtx-Cplx', 'Amy-Exc', 'Chd7', 'Foxp2', 
       'ASC', 'ODC', 'OPC', 'MGC', 'EC', 'PC', 'VLMC', 'merged'
      ]
# binall = np.load(f'{outdir}binfilter_raw.npy', allow_pickle=True)
# modelall = np.load(f'{outdir}pcloading_raw.npy', allow_pickle=True)
def compsaddle(cool):
    sad = np.zeros((50, 50))
    count = np.zeros((50, 50))
    comptmp = []
    for k,chrom in enumerate(chrom_sizes.index[:-1]):
        Q = cool.matrix(balance=False, sparse=True).fetch(chrom).toarray()
        Q = Q - np.diag(np.diag(Q))
        pc = np.zeros(Q.shape[0])
        binfilter = binall[k]
        Q = Q[binfilter][:, binfilter]
        decay = np.array([np.mean(np.diag(Q, i)) for i in range(Q.shape[0])])
        E = np.zeros(Q.shape)
        row, col = np.diag_indices(E.shape[0])
        E[row, col] = 1
        for i in range(1, E.shape[0]):
            E[row[:-i], col[i:]] = (Q[row[:-i], col[i:]] + 1e-5) / (decay[i] + 1e-5)
        E = E + E.T
        C = np.corrcoef(np.log2(E + 0.001))
        pc = (C-np.mean(C, axis=0)).dot(modelall[k])
        comptmp.append(pc)
        labels, groups = pd.qcut(pc, 50, labels=False, retbins=True)
        sad += np.array([[E[np.ix_(labels==i, labels==j)].sum() for i in range(50)] for j in range(50)])
        count += np.array([[(labels==i).sum()*(labels==j).sum() for i in range(50)] for j in range(50)])
    return np.concatenate(comptmp), sad, count
mode = 'raw'
sad = np.zeros((len(leg),50,50))
count = np.zeros((len(leg),50,50))
comp = []
cpu = 10
with ProcessPoolExecutor(cpu) as executor:
    futures = {}
    for t,ct in enumerate(leg):
        if mode=='impute':
            cool = cooler.Cooler(f'{indir}{ct}/{ct}.Q.cool')
        elif mode=='raw':
            cool = cooler.Cooler(f'{indir}{ct}/{ct}/{ct}.raw.mcool::resolutions/100000')
        future = executor.submit(
            compsaddle,
            cool=cool,
        )
        futures[future] = t

    for future in as_completed(futures):
        t = futures[future]
        ct = leg[t]
        xx, yy, zz = future.result()
        comp.append(pd.Series(xx, name=ct))
        sad[t] += yy
        count[t] += zz
        print(f'{ct} finished')
        
L5_ET finished
L56_NP finished
L23_IT finished
L4_IT finished
L6_IT finished
L6b finished
L5_IT finished
L6_IT_Car3 finished
Lamp5 finished
L6_CT finished
Lamp5_LHX6 finished
SubCtx finished
Sncg finished
Pvalb_ChC finished
Vip finished
Pvalb finished
Sst finished
MSN_D1 finished
MSN_D2 finished
Amy finished
CHD7 finished
Foxp2 finished
OPC finished
ODC finished
ASC finished
VLMC finished
EC finished
PC finished
MGC finished
merged finished
bins_df = cool.bins()[:]
bins_df.index = bins_df['chrom'].astype(str) + '-' + (bins_df['start'] // res).astype(str)
compidx = np.concatenate([bins_df.index[bins_df['chrom']==c][binall[i]] for i,c in enumerate(chrom_sizes.index[:-1])])
comp = pd.concat(comp, axis=1)
comp.index = compidx
comp.to_hdf(f'{outdir}comp_{mode}_mergerawpca.hdf', key='data')
sad = sad / count
np.save(f'{outdir}saddle_{mode}_mergerawpca.npy', sad)
ws = 10
compstr = [[sad[i][:ws, :ws].mean(), sad[i][-ws:, -ws:].mean(), sad[i][:ws, -ws:].mean(), sad[i][-ws:, :ws].mean(), 
            (sad[i][:ws, :ws].sum() + sad[i][-ws:, -ws:].sum()) / (sad[i][:ws, -ws:].sum() + sad[i][-ws:, :ws].sum())]
            for i in range(len(leg))]
compstr = pd.DataFrame(compstr, index=leg, columns=['BB', 'AA', 'BA', 'AB', 'strength'])
compstr.sort_values('strength')
BB AA BA AB strength
L5_ET 1.348505 1.262203 0.708481 0.708481 1.842468
L6b 1.291788 1.330709 0.683750 0.683750 1.917731
SubCtx 1.236652 1.368936 0.658978 0.658978 1.976991
L6_IT_Car3 1.299906 1.325469 0.626082 0.626082 2.096671
L56_NP 1.311511 1.331668 0.626919 0.626919 2.108070
L6_IT 1.298863 1.347858 0.623379 0.623379 2.122883
Amy 1.343647 1.332711 0.614503 0.614503 2.177662
Foxp2 1.291581 1.370952 0.607887 0.607887 2.189991
Sncg 1.345780 1.375168 0.619657 0.619657 2.195527
L6_CT 1.291535 1.397869 0.611131 0.611131 2.200350
CHD7 1.297236 1.364787 0.600250 0.600250 2.217427
merged 1.231826 1.475701 0.605544 0.605544 2.235614
Pvalb_ChC 1.274394 1.436267 0.603899 0.603899 2.244299
L23_IT 1.256678 1.393642 0.583772 0.583772 2.269996
L5_IT 1.263534 1.455531 0.597170 0.597170 2.276625
MSN_D1 1.241947 1.448782 0.589151 0.589151 2.283567
MSN_D2 1.277473 1.409101 0.574044 0.574044 2.340041
Vip 1.325455 1.422411 0.579338 0.579338 2.371556
Pvalb 1.265274 1.515272 0.578515 0.578515 2.403175
Sst 1.291474 1.466535 0.571911 0.571911 2.411224
Lamp5_LHX6 1.321902 1.435453 0.563871 0.563871 2.445023
L4_IT 1.228220 1.490808 0.554261 0.554261 2.452842
Lamp5 1.341595 1.420382 0.555473 0.555473 2.486151
EC 1.516187 1.381944 0.542193 0.542193 2.672600
ASC 1.463769 1.367253 0.528202 0.528202 2.679866
VLMC 1.550196 1.322227 0.528425 0.528425 2.717910
OPC 1.614372 1.415822 0.479717 0.479717 3.158317
ODC 1.518158 1.442945 0.465697 0.465697 3.179216
PC 1.683920 1.473416 0.441736 0.441736 3.573785
MGC 1.740897 1.540446 0.326825 0.326825 5.020036
# raw
vmin, vmax = -0.8, 0.8
fig, axes = plt.subplots(((len(leg))//6+1), 6, figsize=(15,15), dpi=300)
fig.subplots_adjust(hspace=0.2, wspace=0.1)
for i,xx in enumerate(leg):
    ax = axes.flatten()[i]
    plot = ax.imshow(np.log2(sad[i]+0.001), cmap='coolwarm', vmin=vmin, vmax=vmax)
    # plot = ax.imshow(sad[i], cmap='coolwarm', vmin=0, vmax=2)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(legname[i], fontsize=15)
    ax.text(40, 45, np.around(compstr.loc[xx, 'AA']/compstr.loc[xx, 'AB'], decimals=3), ha='center', va='center', fontsize=15)
    ax.text(10, 5, np.around(compstr.loc[xx, 'BB']/compstr.loc[xx, 'BA'], decimals=3), ha='center', va='center', fontsize=15)

cbar = plt.colorbar(plot, ax=axes.flatten()[len(leg)])
cbar.set_ticks([vmin,vmax])
cbar.set_label('Normalized Interaction')

for ax in axes.flatten()[len(leg):]:
    ax.axis('off')
    
# plt.tight_layout()
# plt.savefig(f'{indir}/plot/celltype_saddle_raw.pdf', transparent=True)
../../_images/2114e2b743068a03daa681ad122e5fd074eb6f5755c6de8a50815fdd7fbdbc1a.png