# Source code for af_analysis.data

#!/usr/bin/env python3
# coding: utf-8

import os
import numpy as np
import pandas as pd
import json
import pdb_cpp
import seaborn as sns
import matplotlib.pyplot as plt
from cmcrameri import cm
from tqdm.auto import tqdm
import json
import pickle
import logging
import ipywidgets as widgets

logger = logging.getLogger(__name__)

from .format import (
    colabfold_1_5,
    af3_webserver,
    afpulldown,
    boltz1,
    chai1,
    massivefold,
    default,
    af3_local,
)
from . import sequence, plot
from .analysis import get_pae, extract_fields_file


# Authorship information
__author__ = "Alaa Reguei, Samuel Murail"
__copyright__ = "Copyright 2023, RPBS"
__credits__ = ["Samuel Murail", "Alaa Reguei"]
__license__ = "GNU General Public License version 2"
__version__ = "0.2.0"
__maintainer__ = "Samuel Murail"
__email__ = "samuel.murail@u-paris.fr"
__status__ = "Beta"


# Atom names whose B-factor (beta) column is read as pLDDT by
# ``Data.get_plddt`` for non-ligand chains: one representative atom per
# residue ("CA" for amino acids, "P" for nucleotides) plus common ions.
# "CA" is listed once; the same name also matches calcium ions.
plddt_main_atom_list = [
    "CA",
    "P",
    "ZN",
    "MG",
    "CL",
    "NA",
    "MN",
    "K",
    "FE",
    "CU",
    "CO",
]


def _flatten_1d(array_like):
    array = np.asarray(array_like)
    if array.ndim > 1:
        array = array.reshape(-1)
    return array


def _unique_preserve_order(values):
    return list(dict.fromkeys(values))


class Data:
    """Container for structure-prediction results of one or more queries.

    Attributes
    ----------
    verbose : bool
        Print progress information during analysis.
    dir : str or None
        Directory the data was read from; ``None`` when unknown (data built
        from a ``data_dict``, or an empty object).
    format : str or None
        Name of the input format (e.g. ``"colabfold_1.5"``, ``"AF3_local"``,
        ``"boltz1"``, ``"csv"``, ``"custom"``, ``"default"``).
    df : pandas.DataFrame
        One row per model. Columns read by this class include ``query``,
        ``pdb``, ``data_file`` and ``format``.
    chains : dict
        Query name -> list of chain identifiers (first-seen order).
    chain_length : dict
        Query name -> list of chain lengths: number of residues for chains
        made only of standard amino acids / nucleotides, number of atoms
        otherwise.
    chain_type : dict
        Query name -> list of ``"protein"``, ``"nucleic_acid"`` or
        ``"ligand"``.

    Methods
    -------
    read_directory(directory, keep_recycles=False, verbose=True, format=None)
        Read a prediction output directory.
    set_chain_length()
        Fill ``chains``, ``chain_length`` and ``chain_type`` from pdb files.
    export_csv(path)
        Export the dataframe to a csv file.
    import_csv(path)
        Import a csv file written by ``export_csv``.
    add_json()
        Add json files to the dataframe.
    extract_data()
        Extract json/npz file contents as new dataframe columns.
    extract_fields(fields, disable=False)
        Extract selected fields from the data files.
    add_pdb()
        Add pdb files to the dataframe.
    add_fasta(csv)
        Add fasta sequences to the dataframe.
    keep_last_recycle()
        Keep only the last recycle for each query.
    plot_maxscore_as_col(score, col, hue='query')
        Plot the running maximum of a score as a function of a column.
    plot_pae(index, cmap=cm.vik)
        Plot the PAE matrix.
    get_plddt(index)
        Return the pLDDT array of a model.
    plot_plddt(index_list=None)
        Plot the pLDDT.
    show_3d(index)
        Show the 3D structure.
    plot_msa(filter_qid=0.15, filter_cov=0.4)
        Plot the MSA from the a3m files.
    count_msa_seq()
        Count, for each chain, the MSA sequences covering it.
    show_plot_info(cmap=cm.vik)
        Interactive pLDDT / PAE viewer.
    """

    # Columns holding list/array cells: written as Python list literals by
    # ``export_csv`` and parsed back to numpy arrays by ``import_csv``.
    _ARRAY_COLUMNS = ("LIS", "LIA", "ipTM_d0_matrix", "ipSAE_matrix")

    def __init__(
        self, directory=None, data_dict=None, csv=None, verbose=True, format=None
    ):
        """Build a Data object from a directory, a csv export or a dict.

        The first non-``None`` source among ``directory``, ``csv`` and
        ``data_dict`` (checked in that order) is used. With no source an
        empty object is created, to be filled by the caller.

        Parameters
        ----------
        directory : str, optional
            Prediction output directory, see :meth:`read_directory`.
        data_dict : dict, optional
            Column data for the dataframe; must contain the ``pdb``,
            ``query`` and ``data_file`` keys.
        csv : str, optional
            Path to a csv file written by :meth:`export_csv`.
        verbose : bool
            Print progress information.
        format : str, optional
            Format name forwarded to :meth:`read_directory`.
        """
        self.verbose = verbose
        if directory is not None:
            self.read_directory(directory, verbose=verbose, format=format)
        elif csv is not None:
            self.format = "csv"
            self.import_csv(csv)
        elif data_dict is not None:
            for key in ("pdb", "query", "data_file"):
                assert key in data_dict.keys(), f"data_dict is missing the '{key}' key"
            self.df = pd.DataFrame(data_dict)
            self.dir = None
            # Set the attribute too, not only the column: add_json/add_pdb
            # read self.format.
            self.format = "custom"
            self.df["format"] = "custom"
            self.set_chain_length()
        else:
            # Empty container; attributes are filled in by the caller.
            self.dir = None
            self.format = None
            self.df = pd.DataFrame()
            self.chains = {}
            self.chain_length = {}
            self.chain_type = {}

    @staticmethod
    def _is_missing(value):
        """Return True for ``None`` or a float NaN (e.g. an empty csv cell)."""
        return value is None or (isinstance(value, float) and np.isnan(value))

    @staticmethod
    def _detect_format(directory):
        """Guess the output format of ``directory`` from marker files.

        Returns one of ``"colabfold_1.5"`` (``log.txt``), ``"AF3_webserver"``
        (``terms_of_use.md``), ``"boltz1"`` (``predictions/`` sub-directory),
        ``"chai1"`` (``msa_depth.pdf`` or ``pae.rank_0.npy``), else
        ``"default"``.
        """

        def has_file(name):
            return os.path.isfile(os.path.join(directory, name))

        if has_file("log.txt"):
            return "colabfold_1.5"
        if has_file("terms_of_use.md"):
            return "AF3_webserver"
        if os.path.isdir(os.path.join(directory, "predictions")):
            return "boltz1"
        if has_file("msa_depth.pdf") or has_file("pae.rank_0.npy"):
            return "chai1"
        return "default"

    def read_directory(self, directory, keep_recycles=False, verbose=True, format=None):
        """Read a prediction output directory.

        An explicitly requested, supported ``format`` is always honoured.
        Otherwise (``None`` or an unknown name) the format is auto-detected
        from marker files; e.g. a directory containing ``log.txt`` is read
        as ``colabfold_1.5``.

        Parameters
        ----------
        directory : str
            Path to the prediction output directory.
        keep_recycles : bool
            Forwarded to ``colabfold_1_5.read_log`` (``colabfold_1.5`` only).
        verbose : bool
            Forwarded to :meth:`add_pdb` / :meth:`add_json`.
        format : str, optional
            One of ``"colabfold_1.5"``, ``"AF3_local"``, ``"AF3_webserver"``,
            ``"alphapulldown"``, ``"alphapulldown_full"``, ``"boltz1"``,
            ``"chai1"``, ``"massivefold"``, ``"full_massivefold"`` or
            ``"default"``.

        Returns
        -------
        None
        """
        self.dir = directory
        # requested format -> (stored format name, reader)
        readers = {
            "colabfold_1.5": (
                "colabfold_1.5",
                lambda: colabfold_1_5.read_log(directory, keep_recycles),
            ),
            "AF3_local": ("AF3_local", lambda: af3_local.read_dir(directory)),
            "AF3_webserver": (
                "AF3_webserver",
                lambda: af3_webserver.read_dir(directory),
            ),
            "alphapulldown": ("alphapulldown", lambda: afpulldown.read_dir(directory)),
            "alphapulldown_full": (
                "alphapulldown",
                lambda: afpulldown.read_full_dir(directory),
            ),
            "boltz1": ("boltz1", lambda: boltz1.read_dir(directory)),
            "chai1": ("chai1", lambda: chai1.read_dir(directory)),
            "massivefold": ("massivefold", lambda: massivefold.read_dir(directory)),
            "full_massivefold": (
                "massivefold",
                lambda: massivefold.read_full_directory(directory),
            ),
            "default": ("default", lambda: default.read_dir(directory)),
        }
        if format not in readers:
            format = self._detect_format(directory)

        self.format, reader = readers[format]
        self.df = reader()
        self.df["format"] = self.format

        if self.format == "colabfold_1.5":
            self.add_pdb(verbose=verbose)
            self.add_json(verbose=verbose)
        elif self.format == "default":
            self.add_json(verbose=verbose)
        self.set_chain_length()

    def set_chain_length(self):
        """Fill ``chains``, ``chain_length`` and ``chain_type`` for each query.

        Chain information is read from the first model of the first existing
        pdb file of each query. Queries without such a file get empty lists
        and a warning is logged.

        Returns
        -------
        None
        """
        nuc_list = ["DA", "DC", "DT", "DG", "A", "C", "U", "G", "T"]
        aa_list = [
            "ALA", "ARG", "ASN", "ASP", "CYS", "GLU", "GLN", "GLY", "HIS", "ILE",
            "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
        ]
        aa_nuc_set = set(aa_list + nuc_list)

        def get_type(resnames):
            # Any nucleotide makes the chain a nucleic acid, then any amino
            # acid makes it a protein; anything else is a ligand.
            if any(res in nuc_list for res in resnames):
                return "nucleic_acid"
            if any(res in aa_list for res in resnames):
                return "protein"
            return "ligand"

        self.chains = {}
        self.chain_length = {}
        self.chain_type = {}

        for query in self.df["query"].unique():
            first_model = None
            for pdb_path in self.df.loc[self.df["query"] == query, "pdb"]:
                # Skip missing entries: None, or NaN after a csv round-trip.
                if not isinstance(pdb_path, (str, os.PathLike)):
                    continue
                if not os.path.isfile(pdb_path):
                    continue
                coor = pdb_cpp.Coor(pdb_path)
                if coor.models:
                    first_model = coor
                    break

            if first_model is None:
                logger.warning(f"No valid PDB found for query {query}, skipping chain info")
                self.chains[query] = []
                self.chain_length[query] = []
                self.chain_type[query] = []
                continue

            model = first_model.models[0]
            chain_arr = np.asarray(model.chain_str)
            uniq_resid = _flatten_1d(model.uniq_resid)
            resname = np.asarray(model.resname_str)

            chains = _unique_preserve_order(chain_arr.tolist())
            lengths = []
            types = []
            for chain in chains:
                mask = chain_arr == chain
                chain_resnames = resname[mask]
                if all(res in aa_nuc_set for res in chain_resnames):
                    # Standard polymer chain: count residues.
                    lengths.append(len(np.unique(uniq_resid[mask])))
                else:
                    # Chain with non-standard residues: count atoms.
                    lengths.append(int(mask.sum()))
                types.append(get_type(chain_resnames))

            self.chains[query] = chains
            self.chain_length[query] = lengths
            self.chain_type[query] = types

    def export_csv(self, path):
        """Export the dataframe to a csv file.

        Array cells of the ``_ARRAY_COLUMNS`` columns are written as Python
        list literals (not numpy's truncated ``str`` form) so
        :meth:`import_csv` can restore them.

        Parameters
        ----------
        path : str
            Path to the csv file.

        Returns
        -------
        None
        """
        df_out = self.df.copy()
        for col in self._ARRAY_COLUMNS:
            if col in df_out.columns:
                df_out[col] = df_out[col].map(
                    lambda v: v.tolist() if isinstance(v, np.ndarray) else v
                )
        df_out.to_csv(path, index=False)

    def import_csv(self, path):
        """Import a csv file (as written by :meth:`export_csv`) to the dataframe.

        Empty ``pdb`` / ``data_file`` cells are restored as ``None`` and list
        literals in ``_ARRAY_COLUMNS`` are restored as numpy arrays.

        Parameters
        ----------
        path : str
            Path to the csv file.

        Returns
        -------
        None
        """
        import ast

        def _parse(value):
            # Restore list literals as numpy arrays; leave anything else as is.
            if isinstance(value, str):
                try:
                    parsed = ast.literal_eval(value)
                except (ValueError, SyntaxError):
                    return value
                if isinstance(parsed, list):
                    return np.array(parsed)
            return value

        self.df = pd.read_csv(path)

        # read_csv turns empty cells into NaN; the rest of the class tests
        # path columns against None.
        for col in ("pdb", "data_file"):
            if col in self.df.columns:
                self.df[col] = [None if self._is_missing(v) else v for v in self.df[col]]

        pdb_paths = [p for p in self.df["pdb"] if p is not None]
        self.dir = os.path.dirname(pdb_paths[0]) if pdb_paths else None

        for col in self._ARRAY_COLUMNS:
            if col in self.df.columns:
                self.df[col] = self.df[col].apply(_parse)
        self.set_chain_length()

    def add_json(self, verbose=True):
        """Add json files to the dataframe (``colabfold_1.5`` and ``default`` only).

        Parameters
        ----------
        verbose : bool
            Forwarded to the format-specific ``add_json``.

        Returns
        -------
        None
        """
        if self.format == "colabfold_1.5":
            colabfold_1_5.add_json(self.df, self.dir, verbose=verbose)
        elif self.format == "default":
            default.add_json(self.df, self.dir, verbose=verbose)

    def extract_data(self):
        """Load every json/npz ``data_file`` and store its entries as columns.

        Each top-level key becomes a column. Rows whose file is missing, has
        another extension, or lacks a given key get ``None`` for it. Columns
        follow the first-seen key order across files.

        Returns
        -------
        None
        """
        index_list = []
        data_list = []
        for index, data_path in zip(self.df.index, self.df["data_file"]):
            if self._is_missing(data_path):
                continue
            path_str = str(data_path)
            if path_str.endswith(".json"):
                with open(data_path, "r") as f:
                    data = json.load(f)
            elif path_str.endswith(".npz"):
                with np.load(data_path) as data_npz:
                    data = {key: data_npz[key] for key in data_npz.keys()}
            else:
                continue
            data_list.append(data)
            index_list.append(index)

        if not data_list:
            logger.warning("No json/npz data file found, nothing to extract.")
            return

        # Union of keys over all files, so heterogeneous files still align.
        keys = _unique_preserve_order(key for data in data_list for key in data)
        for key in keys:
            self.df.loc[:, key] = None
            self.df.loc[index_list, key] = pd.Series(
                [data.get(key) for data in data_list], index=index_list
            )

    def extract_fields(self, fields, disable=False):
        """Extract fields from data files to the dataframe.

        Parameters
        ----------
        fields : list
            List of fields to extract; each becomes a column.
        disable : bool
            Disable the progress bar.

        Returns
        -------
        None
        """
        values_list = [[] for _ in fields]
        for data_path in tqdm(
            self.df["data_file"], total=len(self.df["data_file"]), disable=disable
        ):
            if self._is_missing(data_path):
                local_values = [None] * len(fields)
            else:
                local_values = extract_fields_file(data_path, fields)
            for i, values in enumerate(values_list):
                values.append(local_values[i])

        for field, values in zip(fields, values_list):
            self.df[field] = None
            self.df.loc[:, field] = pd.Series(values, index=self.df.index)

    def add_pdb(self, verbose=True):
        """Add pdb files to the dataframe (``colabfold_1.5`` only).

        Parameters
        ----------
        verbose : bool
            Forwarded to ``colabfold_1_5.add_pdb``.

        Returns
        -------
        None
        """
        if self.format == "colabfold_1.5":
            colabfold_1_5.add_pdb(self.df, self.dir, verbose=verbose)

    def add_fasta(self, csv):
        """Add fasta sequences to the dataframe (``colabfold_1.5`` only).

        Parameters
        ----------
        csv : str
            Path to the csv file containing the fasta sequences.

        Returns
        -------
        None
        """
        if self.format == "colabfold_1.5":
            colabfold_1_5.add_fasta(self.df, csv)

    def keep_last_recycle(self):
        """Keep only the last recycle for each (query, seed, model, weight)."""
        idx = (
            self.df.groupby(["query", "seed", "model", "weight"])["recycle"].transform(
                "max"
            )
            == self.df["recycle"]
        )
        self.df = self.df[idx]

    def plot_maxscore_as_col(self, score, col, hue="query"):
        """Plot, per ``hue`` group, the best ``score`` reached up to each ``col`` value.

        For each group and each distinct value ``v`` of ``col``, the point
        plotted is the maximum of ``score`` over the group's rows with
        ``col <= v``.

        Parameters
        ----------
        score : str
            Score column to maximise.
        col : str
            Column used as x axis.
        hue : str
            Column defining the groups (one line per group).

        Returns
        -------
        tuple
            ``(fig, ax)``.
        """
        col_values = self.df[col].unique()
        records = []
        for group in self.df[hue].unique():
            group_df = self.df[self.df[hue] == group]
            for value in col_values:
                subset = group_df[group_df[col] <= value]
                if len(subset) > 0:
                    records.append({hue: group, score: subset[score].max(), col: value})

        max_pd = pd.DataFrame(records)
        fig, ax = plt.subplots()
        sns.lineplot(data=max_pd, x=col, y=score, hue=hue, ax=ax)
        return (fig, ax)

    def plot_pae(self, index, cmap=cm.vik):
        """Plot the PAE matrix of the model at row position ``index``.

        Chain boundaries are drawn as black lines and chains are labelled on
        the y axis.

        Parameters
        ----------
        index : int
            Positional index in the dataframe.
        cmap : matplotlib colormap
            Colormap for the PAE values (range 0-30 Å).

        Returns
        -------
        tuple or None
            ``(fig, ax)``, or ``None`` when the row has no data file.
        """
        row = self.df.iloc[index]
        if self._is_missing(row["data_file"]):
            return None
        pae_array = get_pae(row["data_file"])

        chain_length = self.chain_length[row["query"]]
        res_max = sum(chain_length)
        # -0.5 puts the lines between residue cells of the image.
        boundaries = np.cumsum(chain_length[:-1]) - 0.5

        fig, ax = plt.subplots()
        img = ax.imshow(pae_array, cmap=cmap, vmin=0.0, vmax=30.0)
        ax.hlines(boundaries, xmin=-0.5, xmax=res_max, colors="black")
        ax.vlines(boundaries, ymin=-0.5, ymax=res_max, colors="black")
        ax.set_xlim(-0.5, res_max - 0.5)
        ax.set_ylim(res_max - 0.5, -0.5)

        chain_pos = []
        len_sum = 0
        for length in chain_length:
            chain_pos.append(len_sum + length / 2)
            len_sum += length
        ax.set_yticks(chain_pos)
        ax.set_yticklabels(self.chains[row["query"]])

        cbar = fig.colorbar(img, ax=ax)
        cbar.set_label("Predicted Aligned Error (Å)", rotation=270)
        cbar.ax.get_yaxis().labelpad = 15
        return (fig, ax)

    def get_plddt(self, index):
        """Extract the pLDDT array from the pdb file or the json/npz/pkl files.

        Parameters
        ----------
        index : int
            Positional index in the dataframe.

        Returns
        -------
        np.array or None
            pLDDT array, or ``None`` when it cannot be found.
        """
        row = self.df.iloc[index]
        if row["format"] in [
            "AF3_webserver",
            "AF3_local",
            "csv",
            "alphapulldown",
            "chai1",
            "massivefold",
        ]:
            # pLDDT is read from the B-factor column of the structure.
            model = pdb_cpp.Coor(row["pdb"])
            m = model.models[0]
            chain_arr = np.asarray(m.chain_str)
            name_arr = np.asarray(m.name_str)
            beta_arr = _flatten_1d(m.beta)

            query = row["query"]
            chains = self.chains.get(query, [])
            chain_types = self.chain_type.get(query, [])

            # For ligand chains PAE has one row per heavy atom, so pLDDT must match.
            if chains and chain_types and "ligand" in chain_types:
                parts = []
                for chain_id, ctype in zip(chains, chain_types):
                    c_mask = chain_arr == chain_id
                    c_names = name_arr[c_mask]
                    c_beta = beta_arr[c_mask]
                    if ctype == "ligand":
                        # All heavy atoms (names not starting with "H").
                        heavy = np.array(
                            [not n.startswith("H") for n in c_names], dtype=bool
                        )
                        parts.append(c_beta[heavy])
                    else:
                        parts.append(c_beta[np.isin(c_names, plddt_main_atom_list)])
                return np.concatenate(parts) if parts else np.array([])

            return beta_arr[np.isin(name_arr, plddt_main_atom_list)]

        if row["format"] in ["boltz1"]:
            # Stored in the 0-1 range; rescaled to 0-100.
            with np.load(row["plddt"]) as data_npz:
                return data_npz["plddt"] * 100

        data_file = row["data_file"]
        if self._is_missing(data_file):
            return None
        path_str = str(data_file)

        if path_str.endswith(".json"):
            with open(data_file) as f:
                local_json = json.load(f)
            if "plddt" not in local_json:
                return None
            return np.array(local_json["plddt"])

        if path_str.endswith(".npz"):
            with np.load(data_file) as data_npz:
                if "plddt" not in data_npz:
                    return None
                return data_npz["plddt"]

        if path_str.endswith(".pkl"):
            # SECURITY: unpickling can execute arbitrary code; only load
            # result files from a trusted source.
            try:
                data_pkl = np.load(data_file, allow_pickle=True)
            except pickle.UnpicklingError as e:
                logger.error(f"Error loading pLDDT from {row['data_file']}: {e}")
                return None
            return data_pkl["plddt"]

        # Unsupported data file extension.
        return None

    def plot_plddt(self, index_list=None):
        """Plot the pLDDT of the models at the given row positions.

        Chain boundaries are taken from the query of the first index.

        Parameters
        ----------
        index_list : iterable of int, optional
            Positional indices to plot; all rows by default.

        Returns
        -------
        tuple
            ``(fig, ax)``.
        """
        if index_list is None:
            index_list = range(len(self.df))
        # Materialise so generators work and the first index can be reused.
        index_list = list(index_list)

        fig, ax = plt.subplots()
        for index in index_list:
            plddt_array = self.get_plddt(index)
            if plddt_array is None:
                logger.warning(f"No pLDDT found for model {index}, skipping")
                continue
            ax.plot(np.arange(1, len(plddt_array) + 1), plddt_array)

        if index_list:
            chain_length = self.chain_length[self.df.iloc[index_list[0]]["query"]]
            ax.vlines(np.cumsum(chain_length[:-1]), ymin=0, ymax=100.0, colors="black")
            ax.set_xlim(0, sum(chain_length))
        ax.set_ylim(0, 100)
        ax.set_xlabel("Residue")
        ax.set_ylabel("predicted LDDT")
        return (fig, ax)

    def show_3d(self, index):
        """Return an nglview widget of the structure at row position ``index``.

        Returns ``(None, None)`` when the row has no pdb file.
        """
        row = self.df.iloc[index]
        if self._is_missing(row["pdb"]):
            return (None, None)
        import nglview as nv

        # nv.show_file is avoided because of a bug; use show_structure_file.
        return nv.show_structure_file(row["pdb"])

    def plot_msa(self, filter_qid=0.15, filter_cov=0.4):
        """Plot the MSA of every ``.a3m`` file found in ``self.dir``.

        Parameters
        ----------
        filter_qid : float
            Minimal sequence identity to keep a sequence.
        filter_cov : float
            Minimal coverage to keep a sequence.

        Returns
        -------
        None

        .. warning:: only tested with colabfold 1.5
        """
        file_list = [name for name in os.listdir(self.dir) if name.endswith(".a3m")]

        for a3m_file in file_list:
            logger.info(f"Reading MSA file:{a3m_file}")
            query = a3m_file.split("/")[-1].split(".")[0]
            with open(os.path.join(self.dir, a3m_file), "r") as handle:
                # First line is skipped (a3m header line).
                a3m_lines = handle.readlines()[1:]
            seqs, _, _ = sequence.parse_a3m(
                a3m_lines=a3m_lines, filter_qid=filter_qid, filter_cov=filter_cov
            )
            logger.info(f"- Keeping {len(seqs):6} sequences for plotting.")
            feature_dict = {
                "msa": sequence.convert_aa_msa(seqs),
                "num_alignments": len(seqs),
            }
            chain_length = self.chain_length.get(query, [])
            # asym_id is per column, so compare the query sequence length
            # (not the number of sequences) with the total residue count.
            if seqs and len(seqs[0]) == sum(chain_length):
                asym_id = []
                for i, chain_len in enumerate(chain_length):
                    asym_id += [i + 1.0] * chain_len
                feature_dict["asym_id"] = np.array(asym_id)
            plot.plot_msa_v2(feature_dict)
            plt.show()

    def count_msa_seq(self):
        """Count, for each chain, the MSA sequences that are not all gaps on it.

        Returns
        -------
        dict
            Query name -> {chain id: number of sequences}.

        .. warning:: only tested with colabfold 1.5
        """
        file_list = [name for name in os.listdir(self.dir) if name.endswith(".a3m")]

        alignement_len = {}
        for a3m_file in file_list:
            logger.info(f"Reading MSA file:{a3m_file}")
            query = a3m_file.split("/")[-1].split(".")[0]
            with open(os.path.join(self.dir, a3m_file), "r") as handle:
                a3m_lines = handle.readlines()[1:]
            seqs, _, _ = sequence.parse_a3m(
                a3m_lines=a3m_lines, filter_qid=0, filter_cov=0
            )

            chain_list = self.chains[query]
            chain_len_list = self.chain_length[query]
            seq_dict = {chain: 0 for chain in chain_list}
            seq_len = sum(chain_len_list)

            # Homomers: the MSA holds each distinct chain only once. Chains
            # are deduplicated by length, which is a known approximation
            # (should be FIXED by using the original sequences, e.g. from the
            # pdb file).
            if len(seqs[0]) != seq_len:
                kept_len = []
                kept_chain = []
                for i, length in enumerate(chain_len_list):
                    if length not in chain_len_list[:i]:
                        kept_len.append(length)
                        kept_chain.append(chain_list[i])
                chain_len_list = kept_len
                chain_list = kept_chain
                seq_len = sum(chain_len_list)

            assert (
                len(seqs[0]) == seq_len
            ), f"len(seqs[0])={len(seqs[0])} != seq_len={seq_len}"

            for seq in seqs:
                start = 0
                for chain, num in zip(chain_list, chain_len_list):
                    if seq[start : start + num].count("-") < num:
                        seq_dict[chain] += 1
                    start += num
            alignement_len[query] = seq_dict
        return alignement_len

    def show_plot_info(self, cmap=cm.vik):
        """Interactive pLDDT / PAE viewer driven by a model slider.

        Displays an ``IntSlider`` (1-based model number over the dataframe
        rows) and, below it, the pLDDT curve and the PAE matrix of the
        selected model; the plots are redrawn when the slider moves.

        Notes
        -----
        Plots don't update when changing the model number with the
        ``%matplotlib ipympl`` backend; this still needs to be solved.
        """
        model_widget = widgets.IntSlider(
            value=1,
            min=1,
            max=len(self.df),
            step=1,
            description="model:",
            disabled=False,
        )
        # NOTE(review): `display` is not imported in this module; this relies
        # on the IPython/Jupyter builtin and fails outside a notebook.
        display(model_widget)

        def show_model(rank_num):
            # rank_num is the 1-based slider value; df rows are 0-based.
            row = self.df.iloc[rank_num - 1]
            query = row["query"]
            chain_length = self.chain_length[query]
            res_max = sum(chain_length)
            boundaries = np.cumsum(chain_length[:-1])

            fig, (ax_plddt, ax_pae) = plt.subplots(1, 2, figsize=(10, 4))

            plddt_array = self.get_plddt(rank_num - 1)
            if plddt_array is not None:
                ax_plddt.plot(plddt_array)
            ax_plddt.vlines(boundaries, ymin=0, ymax=100.0, colors="black")
            ax_plddt.set_ylim(0, 100)
            ax_plddt.set_xlim(0, res_max)
            ax_plddt.set_xlabel("Residue")
            ax_plddt.set_ylabel("predicted LDDT")

            pae_array = get_pae(row["data_file"])
            ax_pae.imshow(pae_array, cmap=cmap, vmin=0.0, vmax=30.0)
            # -0.5 puts the lines between residue cells, as in plot_pae.
            ax_pae.vlines(boundaries - 0.5, ymin=-0.5, ymax=res_max, colors="yellow")
            ax_pae.hlines(boundaries - 0.5, xmin=-0.5, xmax=res_max, colors="yellow")
            ax_pae.set_xlim(-0.5, res_max - 0.5)
            ax_pae.set_ylim(res_max - 0.5, -0.5)

            chain_pos = []
            len_sum = 0
            for length in chain_length:
                chain_pos.append(len_sum + length / 2)
                len_sum += length
            ax_pae.set_yticks(chain_pos)
            ax_pae.set_yticklabels(self.chains[query])
            plt.show()

        output = widgets.Output(layout={"width": "95%"})
        display(output)

        with output:
            show_model(model_widget.value)

        def on_value_change(change):
            output.clear_output()
            with output:
                show_model(model_widget.value)

        model_widget.observe(on_value_change, names="value")
def concat_data(data_list):
    """Concatenate data from a list of Data objects into a new Data object.

    The inputs are not modified: the chain dictionaries of the result are
    fresh dicts merged from every element (later elements win on duplicate
    query names). ``format`` is taken from the first element. ``dir`` is
    kept when all elements share it, otherwise ``None``.

    Parameters
    ----------
    data_list : list
        List of Data objects.

    Returns
    -------
    Data
        Concatenated Data object.

    Raises
    ------
    ValueError
        If ``data_list`` is empty.
    """
    data_list = list(data_list)
    if not data_list:
        raise ValueError("concat_data() needs at least one Data object.")

    concat = Data(directory=None, csv=None)
    concat.df = pd.concat([data.df for data in data_list], ignore_index=True)
    concat.format = data_list[0].format

    dirs = {getattr(data, "dir", None) for data in data_list}
    concat.dir = dirs.pop() if len(dirs) == 1 else None

    # New dicts: updating data_list[0]'s own dicts would mutate that input.
    concat.chains = {}
    concat.chain_length = {}
    concat.chain_type = {}
    for data in data_list:
        concat.chains.update(data.chains)
        concat.chain_length.update(data.chain_length)
        concat.chain_type.update(data.chain_type)
    return concat
def read_multiple_alphapulldown(directory):
    """Read every AlphaPulldown sub-directory of ``directory``.

    A sub-directory is read when it contains a ``ranking_debug.json`` file.
    Sub-directories are processed in sorted order so the concatenated
    dataframe has a reproducible row order.

    Parameters
    ----------
    directory : str
        Path to the directory containing the directories.

    Returns
    -------
    Data
        Concatenated Data object.

    Raises
    ------
    ValueError
        If no sub-directory contains AlphaPulldown data.
    """
    subdirs = sorted(
        entry
        for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )
    data_list = [
        Data(os.path.join(directory, subdir))
        for subdir in subdirs
        if os.path.isfile(os.path.join(directory, subdir, "ranking_debug.json"))
    ]
    if len(data_list) == 0:
        raise ValueError("No AlphaPulldown data found in the directory.")
    return concat_data(data_list)