Source code for af_analysis.format.default

#!/usr/bin/env python3
# coding: utf-8

import os
import re
import pathlib
import logging
from tqdm.auto import tqdm
import pandas as pd

# NOTE(review): `getLogger()` without a name returns the *root* logger, so the
# `setLevel` call below changes the root logger's level to INFO as a side
# effect of importing this module. Kept as is because callers may rely on it.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Weight-set names accepted by `read_dir` when it parses model file names.
# The identifier is a misspelling of "weights"; it is kept unchanged because
# it is a module-level name that other code may import.
weigths = (
    ["alphafold2", "alphafold2_ptm"]
    + [f"alphafold2_multimer_v{version}" for version in (1, 2, 3)]
    + [f"multimer_v{version}" for version in (1, 2, 3)]
)


def read_dir(directory):
    """Build a table of the model PDB files found in a directory.

    Each file name is split on ``_`` and the position of the first
    ``relaxed`` / ``unrelaxed`` token decides how the other tokens are read:

    * state token first, e.g. ``relaxed_model_1_multimer_v3_pred_0.pdb``:
      the model number is the third token, the weight name spans the tokens
      between the model number and the last two tokens, and the seed is the
      last token. No rank is available (``None``) and the query name is taken
      from the directory name.
    * state token later, e.g.
      ``query_unrelaxed_rank_001_alphafold2_multimer_v3_model_1_seed_000.pdb``:
      everything before the state token is the query name, the rank follows
      the state token, the weight name spans up to the last four tokens, and
      the model and seed are read from the end of the name.

    Hidden files, files not ending in ``.pdb``, files whose first token is
    ``ranked`` and files without a state token are ignored.

    Parameters
    ----------
    directory : str
        Path to the directory containing the PDB files.

    Returns
    -------
    pandas.DataFrame
        One row per parsed file with the columns ``pdb``, ``query``,
        ``rank``, ``state``, ``seed``, ``model`` and ``weight``.

    Raises
    ------
    AssertionError
        If the weight name parsed from a file name is not in ``weigths``.
    ValueError
        If a model, rank or seed token cannot be converted to ``int``.
    """
    logger.info(f"Reading {directory}")

    # Query name for files that start directly with the state token: such
    # names carry no query prefix, so the directory name stands in for it.
    # Computed once, since it does not depend on the file being parsed.
    path_dir = pathlib.Path(directory)
    dir_query = path_dir.parent.name if not path_dir.is_dir() else path_dir.name

    log_dict_list = []
    for file in os.listdir(directory):
        # Only .pdb files that do not start with a dot (hidden files).
        if not file.endswith(".pdb") or file.startswith("."):
            continue

        token = file.split("_")
        if token[0] == "ranked":
            continue

        # Looked up afresh for every file: reusing the index found for a
        # previous file would parse unrelated names with the wrong layout.
        start_index = next(
            (i for i, t in enumerate(token) if t in ("relaxed", "unrelaxed")),
            None,
        )
        if start_index is None:
            logger.warning(
                f"Skipping {file}: no 'relaxed' or 'unrelaxed' token in file name"
            )
            continue

        state = token[start_index]
        # The last token still carries the ".pdb" extension (4 characters).
        seed = int(token[-1][:-4])

        if start_index == 0:
            name_token = dir_query
            rank = None
            weight = "_".join(token[start_index + 3 : -2])
            model = int(token[start_index + 2])
        else:
            name_token = "_".join(token[:start_index])
            rank = int(token[start_index + 2])
            weight = "_".join(token[start_index + 3 : -4])
            model = int(token[-3])

        assert weight in weigths, f"Unknown weight {weight!r} in file {file}"

        log_dict_list.append(
            {
                "pdb": os.path.join(directory, file),
                "query": name_token,
                "rank": rank,
                "state": state,
                "seed": seed,
                "model": model,
                "weight": weight,
            }
        )

    return pd.DataFrame(log_dict_list)
def add_json(log_pd, directory, verbose=True):
    """Find the score file matching each model and store its path.

    For every row of `log_pd`, the directory listing is searched for a file
    named ``<query>_scores_*_<weight>_model_<model>_seed_<seed>.json`` (seed
    zero-padded to three digits). If there is none, a file named
    ``result_model_<model>_<weight>_pred_<seed>.pkl`` is used instead. The
    resulting path (or ``None`` when nothing matches) is written to the
    ``data_file`` column.

    Parameters
    ----------
    log_pd : pandas.DataFrame
        Dataframe with one row per model (e.g. as returned by `read_dir`).
        The ``query``, ``weight``, ``model`` and ``seed`` columns are read.
    directory : str
        Path to the directory containing the json files.
    verbose : bool
        If True, show a progress bar. If False, no progress bar is shown.

    Returns
    -------
    None
        The `log_pd` dataframe is modified in place.
    """
    logger.info("Extracting json files location")

    file_list = [
        file for file in os.listdir(directory) if file.endswith((".json", ".pkl"))
    ]

    json_list = []
    # One lookup per row, in row order, so that `json_list` always has the
    # same length as `log_pd` when it is assigned as a column below.
    n_rows = len(log_pd)
    for i in tqdm(range(n_rows), total=n_rows, disable=not verbose):
        row = log_pd.iloc[i]

        # Query and weight names are literal text, not regex fragments.
        query = re.escape(str(row["query"]))
        weight = re.escape(str(row["weight"]))

        reg = rf"{query}_scores_.*_{weight}_model_{row['model']}_seed_{row['seed']:03d}\.json"
        r = re.compile(reg)
        res = list(filter(r.match, file_list))

        reg_pkl = rf"result_model_{row['model']}_{weight}_pred_{row['seed']}\.pkl"
        r_pkl = re.compile(reg_pkl)
        res_pkl = list(filter(r_pkl.match, file_list))

        if len(res) == 1:
            json_list.append(os.path.join(directory, res[0]))
        elif len(res) == 0:
            if len(res_pkl) == 1:
                json_list.append(os.path.join(directory, res_pkl[0]))
            else:
                logger.warning(f"Not founded : {reg} or {reg_pkl}")
                json_list.append(None)
        else:
            logger.warning(f"Multiple json file for {reg}: {res}")
            # Keep the first match, as a full path like every other branch.
            json_list.append(os.path.join(directory, res[0]))

    log_pd.loc[:, "data_file"] = json_list