# Source code for af_analysis.plot

import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from cmcrameri import cm

from .analysis import get_pae


def plot_msa_v2(feature_dict, sort_lines=True, dpi=100):
    """Plot the sequence coverage of a multiple sequence alignment.

    Taken from:
    https://github.com/sokrypton/ColabFold/blob/main/colabfold/plot.py

    Every row of the image is one aligned sequence, colored by its sequence
    identity to the query (the first row of ``feature_dict["msa"]``). Gap
    positions are left blank. A black curve gives, for each position, the
    number of sequences that are not a gap there. Black vertical lines mark
    chain boundaries and black horizontal lines separate groups of sequences
    that cover the same set of chains.

    Parameters
    ----------
    feature_dict : dict
        Must provide ``"msa"`` and ``"num_alignments"``. ``"msa"`` is assumed
        to be a 2-D integer array (sequences x positions) in which the value
        21 is treated as a gap. ``"num_alignments"`` may be a scalar or an
        indexable whose first item is the number of sequences to use. If
        ``"asym_id"`` is present, runs of identical consecutive values define
        the chains; otherwise the whole query is a single chain.
    sort_lines : bool, default=True
        If True, sequences inside each group are ordered by increasing
        identity to the query; otherwise their original order is reversed.
    dpi : int, default=100
        Resolution passed to ``plt.figure``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the coverage plot.
    """
    query_seq = feature_dict["msa"][0]

    # Chain lengths are the run lengths of consecutive identical asym_id
    # values; without asym_id there is one chain spanning the query.
    if "asym_id" in feature_dict:
        asym_id = feature_dict["asym_id"]
        chain_lengths = [0]
        previous = asym_id[0]
        for current in asym_id:
            if current == previous:
                chain_lengths[-1] += 1
            else:
                chain_lengths.append(1)
            previous = current
    else:
        chain_lengths = [len(query_seq)]
    chain_bounds = np.cumsum([0] + chain_lengths)
    chain_slices = [
        slice(start, stop)
        for start, stop in zip(chain_bounds[:-1], chain_bounds[1:])
    ]

    # "num_alignments" is either an indexable (take the first item) or a
    # plain scalar. Subscripting a scalar raises TypeError (Python int) or
    # IndexError (NumPy scalar / 0-d array); nothing else is swallowed.
    try:
        num_alignments = feature_dict["num_alignments"][0]
    except (TypeError, IndexError):
        num_alignments = feature_dict["num_alignments"]

    msa = feature_dict["msa"][:num_alignments]
    not_gap = msa != 21
    same_as_query = msa == query_seq

    # For every sequence, one flag per chain: does it cover that chain at all?
    chain_coverage = np.stack(
        [not_gap[:, chain].max(-1) for chain in chain_slices], -1
    )

    lines = []
    group_sizes = []
    # Sequences sharing the same chain-coverage pattern are drawn together.
    for pattern in np.unique(chain_coverage, axis=0):
        rows = np.where((chain_coverage == pattern).all(axis=-1))
        group_identity = same_as_query[rows]
        group_not_gap = not_gap[rows]

        # Mean per-chain identity, averaged over the chains this group covers
        # (the epsilon avoids dividing by zero for an all-gap pattern).
        seqid = np.stack(
            [group_identity[:, chain].mean(-1) for chain in chain_slices], -1
        ).sum(-1) / (pattern.sum(-1) + 1e-8)

        # NaN at gap positions so imshow leaves them blank.
        non_gaps = group_not_gap.astype(float)
        non_gaps[non_gaps == 0] = np.nan

        if sort_lines:
            order = seqid.argsort()
            group_lines = non_gaps[order] * seqid[order, None]
        else:
            group_lines = non_gaps[::-1] * seqid[::-1, None]

        group_sizes.append(len(group_lines))
        lines.append(group_lines)

    group_bounds = np.cumsum(np.append(0, group_sizes))
    lines = np.concatenate(lines, 0)
    n_rows, n_cols = lines.shape[0], lines.shape[1]

    fig = plt.figure(figsize=(8, 5), dpi=dpi)
    plt.title("Sequence coverage")
    plt.imshow(
        lines,
        interpolation="nearest",
        aspect="auto",
        cmap="rainbow_r",
        vmin=0,
        vmax=1,
        origin="lower",
        extent=(0, n_cols, 0, n_rows),
    )
    # Inner chain boundaries (vertical) and group boundaries (horizontal).
    for x in chain_bounds[1:-1]:
        plt.plot([x, x], [0, n_rows], color="black")
    for y in group_bounds[1:-1]:
        plt.plot([0, n_cols], [y, y], color="black")

    # Per-position count of non-gap sequences.
    plt.plot((~np.isnan(lines)).sum(0), color="black")
    plt.xlim(0, n_cols)
    plt.ylim(0, n_rows)
    plt.colorbar(label="Sequence identity to query")
    plt.xlabel("Positions")
    plt.ylabel("Sequences")
    return fig
def show_info(
    data_af,
    cmap=cm.vik,
    score_list=("pLDDT", "pTM", "ipTM", "ranking_confidence"),
):
    """Interactively browse the pLDDT, PAE and scores of each model.

    Displays an integer slider selecting a model (1-based row of
    ``data_af.df``), a figure with the per-residue pLDDT curve and the PAE
    matrix of that model, and an HTML line listing the requested scores.
    Moving the slider refreshes all three.

    Use with ``%matplotlib widget`` in a Jupyter notebook.

    Parameters
    ----------
    data_af : object
        Must provide ``df`` (a table with ``"query"`` and ``"data_file"``
        columns plus the score columns), ``get_plddt(index)``, and the
        ``chain_length`` / ``chains`` mappings keyed by query.
    cmap : colormap, default=cm.vik
        Colormap used for the PAE matrix.
    score_list : sequence of str
        Names of the ``data_af.df`` columns to print. Names that are not
        columns, and scores that are ``None``, are skipped.

    Returns
    -------
    None
    """
    pattern = "<p style='display: inline-block; width:100px'> <strong>{score_name:15} : </strong> {score_value:7.2f} </p>"

    def _scores_html(row_index):
        """Return the HTML snippet listing the available scores of one row."""
        row = data_af.df.iloc[row_index]
        chunks = []
        for score in score_list:
            if score not in data_af.df.columns:
                continue
            value = row[score]
            # A missing score cannot be formatted with ':7.2f'; skip it both
            # at start-up and on every slider update.
            if value is None:
                continue
            chunks.append(pattern.format(score_name=score, score_value=value))
        return "".join(chunks)

    model_widget = widgets.IntSlider(
        value=1,
        min=1,
        max=len(data_af.df),
        step=1,
        description="model:",
        disabled=False,
    )
    # NOTE(review): ``display`` is not imported in the visible part of this
    # module; it is only available where IPython/Jupyter injects it.
    display(model_widget)

    first_index = model_widget.value - 1
    fig, (ax_plddt, ax_pae) = plt.subplots(1, 2, figsize=(10, 4))

    # --- pLDDT panel -----------------------------------------------------
    plddt_array = data_af.get_plddt(first_index)
    (plddt_plot,) = ax_plddt.plot(plddt_array)

    query = data_af.df.iloc[first_index]["query"]
    data_file = data_af.df.iloc[first_index]["data_file"]
    chain_lengths = data_af.chain_length[query]
    # Cumulative lengths of all chains but the last = inner chain boundaries.
    boundaries = np.cumsum(chain_lengths[:-1])
    res_max = sum(chain_lengths)

    vline_plot = ax_plddt.vlines(
        boundaries,
        ymin=0,
        ymax=100.0,
        colors="black",
    )
    ax_plddt.set_ylim(0, 100)
    ax_plddt.set_xlim(0, res_max)
    ax_plddt.set_xlabel("Residue")
    ax_plddt.set_ylabel("predicted LDDT")

    # --- PAE panel -------------------------------------------------------
    pae_array = get_pae(data_file)
    pae_plot = ax_pae.imshow(
        pae_array,
        cmap=cmap,
        vmin=0.0,
        vmax=30.0,
    )
    vline_pae = ax_pae.vlines(
        boundaries,
        ymin=-0.5,
        ymax=res_max,
        colors="yellow",
    )
    hline_pae = ax_pae.hlines(
        boundaries,
        xmin=-0.5,
        xmax=res_max,
        colors="yellow",
    )
    ax_pae.set_xlim(-0.5, res_max - 0.5)
    ax_pae.set_ylim(res_max - 0.5, -0.5)

    # One y tick per chain, placed at the middle of the chain.
    chain_pos = []
    len_sum = 0
    for length in chain_lengths:
        chain_pos.append(len_sum + length / 2)
        len_sum += length
    ax_pae.set_yticks(chain_pos)
    ax_pae.set_yticklabels(data_af.chains[query])

    # ``plt.show`` only takes the keyword ``block``; the figure must not be
    # passed positionally.
    plt.show()

    # --- scores ----------------------------------------------------------
    out_score = widgets.HTML()
    display(out_score)
    out_score.value += _scores_html(first_index)

    def update_model(change):
        """Redraw both panels and the scores for the model on the slider.

        ``change`` is the notification passed by ``observe``; the current
        slider value is read from the widget itself.
        """
        rank_num = model_widget.value
        row_index = rank_num - 1

        plddt_array = data_af.get_plddt(row_index)
        res_num = len(plddt_array)
        plddt_plot.set_data(range(res_num), plddt_array)
        ax_plddt.set_xlim(0, res_num)

        query = data_af.df.iloc[row_index]["query"]
        chain_lengths = data_af.chain_length[query]
        boundaries = np.cumsum(chain_lengths[:-1])
        vline_plot.set_segments(
            [np.array([[x, 0], [x, 100]]) for x in boundaries]
        )

        data_file = data_af.df.iloc[row_index]["data_file"]
        pae_array = get_pae(data_file)
        # From here on the image spans (0, res_num) on both axes. With
        # imshow's default upper origin, matrix row 0 is drawn at the top
        # (y = res_num), hence the ``res_num - x`` flips below.
        pae_plot.set_extent((0, res_num, 0, res_num))
        pae_plot.set_data(pae_array)
        ax_pae.set_xlim(0, res_num)
        ax_pae.set_ylim(0, res_num)
        vline_pae.set_segments(
            [np.array([[x, -0.5], [x, res_num]]) for x in boundaries]
        )
        hline_pae.set_segments(
            [
                np.array([[-0.5, res_num - x], [res_num, res_num - x]])
                for x in boundaries
            ]
        )

        chain_pos = []
        len_sum = 0
        for length in chain_lengths:
            chain_pos.append(res_num - (len_sum + length / 2))
            len_sum += length
        ax_pae.set_yticks(chain_pos)
        ax_pae.set_yticklabels(data_af.chains[query])
        fig.canvas.draw()

        out_score.value = _scores_html(row_index)

    model_widget.observe(update_model, names="value")