#!/usr/bin/env python

import json
import os
import tqdm
import pathlib
import tempfile

import click
from click_didyoumean import DYMGroup

import numpy as np
import pandas as pd
import mdtraj as md

import IMP
import IMP.core
import IMP.algebra
import IMP.atom
import IMP.container
import IMP.kinematics
import RMF
import IMP.pmi.tools
import IMP.pmi.topology
import IMP.rmf

import IMP.bff
import IMP.bff.restraints

from IMP.bff.cli import read_angle_file, WriteRMFFrame, WritePDBFrame


@click.command(
    name="flexfit",
    help="flexible fitting to experimental inter-label distances"
)
@click.option('-i', '--input-pdb', help='Input PDB file.', required=True)
@click.option('-l', '--labeling', help='Input JSON (restraints & flexible amino acids)', required=True)
@click.option('-c', '--score-set', help='Distance set used for scoring', required=True)
@click.option('-f', '--flex-set', help='Name of flexible amino acid set', required=False)
@click.option('-o', '--output', help='Out file either RMF3 or PDB', required=False, default="out.rmf3")
@click.option('-n', '--num-frames', help='Total number of frames.', required=False, default=15000)
@click.option('-p', '--period', help='Steps between write to output file', required=False, default=10)
@click.option('-t', '--temperature', help='Temperature in sampling step.', required=False, default=2.0)
@click.option('-r', '--radii-scaling', help='Radii scaling parameter (0.3 < r < 1.0).', required=False, default=0.7)
@click.option('-m', '--mode', help='Mode: (F)ast computes the positional distributions '
                                   'of the labels once and updates the mean label positions with '
                                   'reference to the rigid body the label is attached to. The '
                                   'option (S)low recomputes the AV at each iteration.',
              required=False, default="F")
def flexfit(
        input_pdb: str,
        labeling: str,
        score_set: str = '',
        flex_set: str = '',
        num_frames: int = 15000,
        output: str = 'out.rmf3',
        period: int = 10,
        temperature: float = 2.0,
        radii_scaling: float = 0.7,
        mode: str = False,
        do_anneal: bool = True
):
    """Monte Carlo sampling of backbone joints scored against label distances.

    The structure in ``input_pdb`` is read without waters and hydrogens,
    typed with the CHARMM heavy-atom topology, and turned into a kinematic
    chain whose joints are moved by a ``RevoluteJointMover``. The scoring
    function combines a soft-sphere pair score on close atom pairs with an
    ``AVNetworkRestraintWrapper`` built from ``labeling`` and ``score_set``.

    Parameters
    ----------
    input_pdb
        Path of the starting structure (PDB).
    labeling
        JSON file. Its ``'FlexFit'`` entry holds the flexible residue sets;
        the same file is handed to the AV restraint as ``fps_json_fn``.
    score_set
        Name of the distance set passed to the AV restraint.
    flex_set
        Key into ``'FlexFit'``. Only consulted when the JSON defines more
        than one set; with a single set that set is always used.
    num_frames
        Total number of sampling steps (run in chunks of ``period``).
    output
        Output file. A ``.rmf3`` extension selects ``WriteRMFFrame``, anything
        else ``WritePDBFrame``. If the file already exists it is opened as an
        RMF file, a stored frame is used as the starting structure, and the
        annealing phase is skipped.
    period
        Number of steps between writes of the output optimizer state.
    temperature
        Final annealing temperature (annealing starts at four times this value).
    radii_scaling
        Scaling factor passed to ``CHARMMParameters.add_radii``.
    mode
        ``"F"`` enables the mean-position AV restraint; any other value
        disables it. The command line default is ``"F"``.
    do_anneal
        Run the annealing phase before sampling.
    """
    model = IMP.Model()
    system = IMP.pmi.topology.System(model)
    output_type = str(os.path.splitext(output)[-1])
    if output_type.lower() == ".rmf3":
        optimizer_class = WriteRMFFrame
    else:
        optimizer_class = WritePDBFrame

    if pathlib.Path(output).exists():
        # Restart: take the coordinates stored in the existing output file,
        # dump them to a temporary PDB and continue from there.
        f = RMF.open_rmf_file_read_only(output)
        frame_index = f.get_frames().max - 1
        IMP.rmf.load_frame(f, RMF.FrameID(frame_index))
        rmf_mhd = IMP.rmf.create_hierarchies(f, model)
        root_hier = rmf_mhd[0]
        pro = root_hier.get_children()[0]
        # NamedTemporaryFile creates the file atomically; the deprecated
        # tempfile.mktemp only returned a name that could be claimed by
        # another process first. The file is kept on disk (delete=False)
        # because it is read back by name below.
        with tempfile.NamedTemporaryFile(suffix='.pdb', delete=False) as tmp:
            tmp_pdb = tmp.name
        IMP.atom.write_pdb(pro, tmp_pdb)
        input_pdb = tmp_pdb
        do_anneal = False

    root_hier = system.build()
    mhd = IMP.atom.read_pdb(input_pdb, model, IMP.atom.NonWaterNonHydrogenPDBSelector(), True, False)
    root_hier.add_child(mhd)

    atoms = IMP.atom.get_by_type(mhd, IMP.atom.ATOM_TYPE)

    topology_file_name = IMP.atom.get_data_path("top_heav.lib")
    parameter_file_name = IMP.atom.get_data_path("par.lib")
    ff = IMP.atom.CHARMMParameters(topology_file_name, parameter_file_name)
    ff.add_radii(mhd, radii_scaling)

    topology = ff.create_topology(mhd)
    topology.apply_default_patches()
    topology.add_atom_types(mhd)
    bonds = topology.add_bonds(mhd)

    # create phi/psi joints
    with open(labeling) as fp:
        flex_sets = json.load(fp)['FlexFit']
    if len(flex_sets) <= 1:
        # A single set is used regardless of what was requested.
        flex_set = list(flex_sets)[0]
    flex_dict = flex_sets[flex_set]
    flexible_residues, extra_bonds = read_angle_file(
        hier=mhd,
        flex_dict=flex_dict
    )
    bonds += extra_bonds
    angles = ff.create_angles(bonds)
    dihedrals = ff.create_dihedrals(bonds)

    pk = IMP.kinematics.ProteinKinematics(mhd, flexible_residues, [])
    joints = pk.get_ordered_joints()
    kf = pk.get_kinematic_forest()
    rbs = pk.get_rigid_bodies()
    kfss = IMP.kinematics.KinematicForestScoreState(kf, rbs, atoms)
    model.add_score_state(kfss)

    # create dofs
    mover = [IMP.kinematics.RevoluteJointMover(model, joints)]

    # Soft sphere score
    ####################
    # prepare exclusions list
    pair_filter = IMP.atom.StereochemistryPairFilter()
    pair_filter.set_bonds(bonds)
    pair_filter.set_angles(angles)
    pair_filter.set_dihedrals(dihedrals)
    # close pair container
    lsc = IMP.container.ListSingletonContainer(model, IMP.get_indexes(atoms))
    cpc = IMP.container.ClosePairContainer(lsc, 15.0)
    cpc.add_pair_filter(pair_filter)
    weight_clash = 1
    score = IMP.core.SoftSpherePairScore(weight_clash)
    pr = IMP.container.create_restraint(score, cpc)
    pr.set_name("stereochemistry")

    # AV score
    ####################
    use_mean_positions = mode == "F"
    av_restraint = IMP.bff.restraints.AVNetworkRestraintWrapper(
        hier=root_hier,
        score_set=score_set,
        fps_json_fn=labeling,
        mean_position_restraint=use_mean_positions
    )
    av_restraint.add_to_model()
    all_restraints = [av_restraint.rs, pr]
    sf = IMP.core.RestraintsScoringFunction(all_restraints)

    s = IMP.core.MonteCarlo(model)
    s.set_scoring_function(sf)
    s.set_return_best(False)

    sm = IMP.core.SerialMover(mover)
    s.add_mover(sm)

    o_state = optimizer_class(
        filename=output,
        root_hier=root_hier,
        restraints=all_restraints
    )
    o_state.set_period(period)
    s.add_optimizer_state(o_state)

    if do_anneal:
        # Annealing
        num_frames_anneal = 10
        n_anneal = 25
        temp_start = 4 * temperature
        # Geometric schedule from temp_start down to temperature.
        # np.logspace works in base 10, so feeding it natural logarithms
        # (as done previously) produced a schedule of 10**ln(T) instead.
        temperatures = np.geomspace(temp_start, temperature, n_anneal)
        for kt in tqdm.tqdm(temperatures, "Annealing"):
            s.set_kt(kt)
            s.optimize(num_frames_anneal)
    # Sampling
    for _ in tqdm.tqdm(range(0, num_frames, period), "Sampling"):
        s.optimize(period)


@click.command(name="rmsd", help="Compute RMSDs ")
@click.option('-f', '--olga-ol4-file', help='Input OL4 file.', required=True)
@click.option('-p', '--rel-path', help='Relative path to PDB files / trajectory (default: ..).', required=False, default="..")
@click.option('-r', '--reference-structure-file', help='Path to reference PDB file (multiple files possible).', required=True, multiple=True)
@click.option('-k', '--structure-key', help='Column name containing the structure information (file name, frame nbr)', required=False, default='structure')
def rmsd(
        olga_ol4_file: str,
        reference_structure_file: tuple,
        rel_path: str = "..",
        structure_key: str = 'structure',
        n_best: int = 10
):
    r"""Add RMSD columns to an OL4 table and write it as ``<name>.rmsd.ol4``.

    The OL4 file is read as a tab-separated table. The column
    ``structure_key`` is split at the first space into ``file`` and
    ``frame``. One ``RMSD_<basename>`` column is added per reference file,
    plus one ``best_<i>_frame_<frame>`` column for each of the ``n_best``
    rows with the lowest value in the ``all`` column. RMSDs returned by
    mdtraj are multiplied by 10 before being stored.

    Example
    -------

    >>> olga_ol4_file = "~/science/wynton/scratch/bft/eTCSPC/Antibody/expansion/major/run6.ol4"
    >>> reference_structure_file = "~/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/rrt/node_begin.pdb"
    >>> relpath = ".."

    Command line
    ------------

    Without column for RMSD reference calculation

    Major eTCSPC structure in Antibody example

    structure rmsd \
        -f ~/science/wynton/scratch/bft/eTCSPC/Antibody/expansion/major/run_5_6_7.ol4 \
        -p . \
        -r ~/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/targets/nodes1_001.pdb \
        -k structure


    Minor eTCSPC structure in Antibody example, multiple PDB files for RMSD calculation

    structure rmsd \
        -f ~/science/wynton/scratch/bft/eTCSPC/Antibody/expansion/minor/run_5_6_7_minor.ol4 \
        -p . \
        -r ~/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/targets/nodes1_001.pdb \
        -r ~/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/targets/nodes1_044.pdb \
        -r ~/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/targets/nodes1_045.pdb \
        -k structure

    With provided column for RMSD reference calculation

    structure rmsd \
        -f /home/tpeulen/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/targets/experiments/eTCSPC/ucfret/02_structure_search/local_scaling/species_sampling.sort.ol4 \
        -p ../../../../../../rrt/long/ \
        -r ~/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/targets/structures/nodes1_001.pdb \
        -k structure1

    Parameters
    ----------
    olga_ol4_file
        Tab-separated OL4 table; must contain the columns ``all`` and
        ``structure_key``.
    reference_structure_file
        Sequence of reference PDB files (the option may be given several
        times). The first entry is also used as topology when loading the
        trajectory.
    rel_path
        Path of the structure files relative to the directory of the OL4 file.
    structure_key
        Column holding ``"<file> <frame>"`` entries.
    n_best
        Number of lowest-scoring rows whose frames are used as additional
        references.

    Returns
    -------
    None
        The augmented table is written next to the input file.
    """
    print("Adding RMSDs to OL4 file")
    print("========================")
    print("OL4 File:", olga_ol4_file)
    print("Reference structure:", reference_structure_file)
    print("Structure key:", structure_key)
    df = pd.read_csv(olga_ol4_file, sep="\t")
    path = os.path.join(os.path.dirname(olga_ol4_file), rel_path)
    top_file = reference_structure_file[0]

    # Best index. Use the Series (not a one-column DataFrame) so that idxmin
    # yields a scalar; int() on a single-element Series is deprecated.
    best_idx = int(df['all'].idxmin())
    all_fret_scores = np.array(df['all'])
    sorted_idx = np.argsort(all_fret_scores)

    # A ".dcd" file name in the first row selects trajectory mode.
    use_traj = ".dcd" in df[structure_key][0]
    # ``n`` must be passed by keyword: pandas >= 2.0 rejects it positionally.
    df[['file', 'frame']] = df[structure_key].str.split(' ', n=1, expand=True)
    best_file, best_frame = str(df.iloc[best_idx]['file']), int(df.iloc[best_idx]['frame'])
    print("best_score:", float(df['all'][best_idx]))
    print("best_file:", best_file)
    print("best_frame:", best_frame)
    print("Loading trajectory...")
    traj = md.load(os.path.join(path, best_file), top=top_file)

    references = {}
    for reference_file in reference_structure_file:
        # Add new columns for reference_structure
        name = "RMSD_%s" % os.path.basename(reference_file)
        df[name] = np.nan
        references[name] = md.load(reference_file)

    # The frames of the n_best lowest-scoring rows serve as extra references.
    for i, best_sample in enumerate(sorted_idx[:n_best]):
        best_frame = int(df.iloc[best_sample]['frame'])
        references['best_{:d}_frame_{:d}'.format(i, best_frame)] = traj[best_frame]

    print("Computing RMSD:")
    print("---------------")
    for name, reference in references.items():
        print("Reference:", name)
        if use_traj:
            rmsds = md.rmsd(traj, reference) * 10.0  # Prefer RMSD in Angstrom
            # Write RMSD values to df column
            for i, row in df.iterrows():
                df.at[i, name] = rmsds[int(row['frame'])]
        else:
            # One structure file per row: load each and compare it on its own.
            for i, row in df.iterrows():
                pdb_fn = os.path.join(path, row['file'])
                structure = md.load(pdb_fn)
                rmsd_value = md.rmsd(structure, reference)[0] * 10.0
                print("PDB/RMSD:", pdb_fn, rmsd_value)
                df.at[i, name] = rmsd_value
    nf = os.path.splitext(olga_ol4_file)[0] + '.rmsd.ol4'
    df.to_csv(nf, sep='\t')


@click.command(name="decays")
@click.argument('input-file')
@click.option('-o', '--output-path', help='Output path', required=False)
@click.option('-s', '--settings', help='Analysis settings YAML file', required=False)
@click.option('-f', '--fret-efficiency', help='FRET efficiency (mean and SD)', required=False, nargs=2, type=float, default=(None, None))
def auto_model(input_file, output_path=None, settings=None, fret_efficiency=None):
    """Run the decay pre-analysis, model selection and distance sampling.

    Creates the sub-directories ``init``, ``sampling`` and ``analysis``
    below the output path. For each of the experiments ``DA_DexDem`` and
    ``DO_DexDem`` it prepares the decay, optimizes models while adding one
    lifetime per step (steps 1 to 7), and selects a model from the resulting
    ``*_n*.json`` files. The two selected decays are then passed to the
    distance sampling.

    Parameters
    ----------
    input_file
        Experiment description file (see the examples below).
    output_path
        Directory receiving the results. Defaults to the directory that
        contains ``input_file``.
    settings
        Sampling settings file, forwarded as ``sampling_settings_file``.
    fret_efficiency
        Pair (mean, SD) forwarded to the sampling step.

    Examples
    --------

    autofret decays ~/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/targets/experiments/eTCSPC/A154_B417_run2/sample.yml

    autofret decays /home/tpeulen/science/Papers/00_in_preparation/Bayesian_Fluorescence_Toolkit/Test_case_candiates/01_Antibody/targets/experiments/eTCSPC/A154_B417_run3_with_fret/sample.yml --fret-efficiency 0.519 0.05

    Returns
    -------
    None
    """
    # NOTE(review): every ``main.<command>.callback`` call below expects
    # ``main`` to be a module or object exposing click commands. In this file
    # ``main`` is a plain function, so these lookups cannot resolve here --
    # confirm which module was intended and import it.
    if output_path is None:
        output_path = pathlib.Path(input_file).parent.absolute()
    else:
        # click hands the option over as str; the "/" joins below need a Path.
        output_path = pathlib.Path(output_path)
    input_file = pathlib.Path(input_file).absolute()

    print("OUTPUT DIRECTORIES")
    print("===============================")
    out_preanalysis = output_path / 'init'
    out_sampling = output_path / 'sampling'
    out_analysis = output_path / 'analysis'
    for directory in (out_preanalysis, out_sampling, out_analysis):
        directory.mkdir(parents=True, exist_ok=True)
    print(output_path)
    print(out_preanalysis)
    print(out_sampling)
    print(out_analysis)

    print("\nINITIAL DECAY OPTIMIZATION")
    print("===============================")
    main.create_decays.callback(experiment_file=input_file, output_path=out_preanalysis)
    for experiment in ('DA_DexDem', 'DO_DexDem'):
        print(experiment)
        print("===============================")
        decay_fn = out_preanalysis / str(experiment + ".json")
        main.rebin.callback(input_decay_file=decay_fn)
        main.set_analysis_range.callback(input_decay_file=decay_fn)
        main.data_smooth_linearization.callback(input_decay_file=decay_fn)
        main.estimate_background.callback(input_decay_file=decay_fn)
        main.estimate_lifetime.callback(input_decay_file=decay_fn)
        main.set_score_type.callback(
            score_type='poisson',
            input_decay_file=decay_fn,
            output_decay_file=str(out_preanalysis / str(experiment + "_n1.json"))
        )
        print("===============================")
        print("#Lifetimes, Score")
        for step in range(1, 8):
            decay_json_fn = str(out_preanalysis / str(experiment + "_n%s" % step + ".json"))
            decay_json_next_fn = str(out_preanalysis / str(experiment + "_n%s" % (step + 1) + ".json"))
            # Optimize in place, plot, then seed the next step with one
            # additional lifetime.
            score = main.model_optimize.callback(
                input_decay_file=decay_json_fn,
                output_decay_file=decay_json_fn
            )
            main.plot_json_decay.callback(
                input_decay_file=decay_json_fn,
                output="file"
            )
            main.model_lifetime_add.callback(
                input_decay_file=decay_json_fn,
                output_decay_file=decay_json_next_fn
            )
            print(step, score)

        output_decay_json_file = str(out_preanalysis / str(experiment + "_selected.json"))
        output_decay_png_file = str(out_preanalysis / str(experiment + "_selected.png"))
        # Path.glob replaces glob.glob (the glob module is not imported in
        # this file); sorting makes the candidate order deterministic.
        input_files = sorted(
            str(candidate)
            for candidate in out_preanalysis.glob(experiment + "_n*.json")
        )
        main.model_selection.callback(
            input_files=input_files,
            output_decay_file=output_decay_json_file
        )
        # Move the PNG found at the "_selected.png" path out of the way before
        # the selected decay is plotted below.
        pathlib.Path(output_decay_png_file).rename(out_preanalysis / str(experiment + "_all_scores.png"))
        main.plot_json_decay.callback(
            input_decay_file=str(out_preanalysis / str(experiment + "_selected.json")),
            output='file'
        )
    print("\nSAMPLING")
    print("===============================")
    main.sample_decay_distance.callback(
        da_dex_dem=out_preanalysis / "DA_DexDem_selected.json",
        do_dex_dem=out_preanalysis / "DO_DexDem_selected.json",
        output=out_sampling / "sampling.hdf",
        fret_efficiency=fret_efficiency,
        sampling_settings_file=settings
    )


@click.group(cls=DYMGroup)
def cli():
    # Root command group; the sub-commands are attached in main().
    # Deliberately left without a docstring: click would show it as the
    # group's help text.
    return None


def main():
    """Attach the sub-commands to the ``cli`` group and invoke it."""
    # Sample
    for command in (flexfit, rmsd):
        cli.add_command(command)
    cli()


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
