#!/usr/bin/env python

import IMP
import IMP.core as core
import IMP.atom as atom
import IMP.em2d as em2d
import os
import sys
import time
import logging
log = logging.getLogger("score_model")


def score_model(complete_fn_model,
                images_sel_file,
                pixel_size,
                n_projections=20,
                resolution=1,
                images_per_batch=250):
    """ Score a model

    Scores a model against the images in the selection file.
    Reads the images in batchs to avoid memory problems
    resolution and pixel_size are used for generating projections of the model
    The finder is an em2d.ProjectionFinder used for projection matching and optimizations
    """
    print("SCORING MODEL: " + complete_fn_model)
    cwd = os.getcwd()

    images_dir, nil = os.path.split(images_sel_file)
    images_names = em2d.read_selection_file(images_sel_file)
    n_images = len(images_names)
    if(n_images == 0):
        raise ValueError(" Scoring with a empty set of images")

    # TYPICAL OPTIMIZER PARAMETERS
    params = em2d.Em2DRestraintParameters(pixel_size, resolution)
    params.coarse_registration_method = em2d.ALIGN2D_PREPROCESSING
    params.optimization_steps = 100
    #params.optimization_steps = 4      # commented by SJ, to improve the optimization process
    params.simplex_initial_length = 0.1
    params.simplex_minimum_size = 0.02
    params.save_match_images = True
    score_function = em2d.EM2DScore()
    finder = em2d.ProjectionFinder()
    finder.setup(score_function, params)

    # Get the number of rows and cols from the 1st image
    srw = em2d.SpiderImageReaderWriter()
    test_imgs = em2d.read_images(
        [os.path.join(images_dir, images_names[0])], srw)
    rows = test_imgs[0].get_header().get_number_of_columns()
    cols = test_imgs[0].get_header().get_number_of_rows()

    model = IMP.Model()
    ssel = atom.ATOMPDBSelector()
    prot = atom.read_pdb(complete_fn_model, model, ssel)
    particles = IMP.core.get_leaves(prot)
    # generate projections
    proj_params = em2d.get_evenly_distributed_registration_results(
        n_projections)
    opts = em2d.ProjectingOptions(pixel_size, resolution)
    projections = em2d.get_projections(
        particles,
        proj_params,
        rows,
        cols,
        opts)

    finder.set_model_particles(particles)
    finder.set_projections(projections)
    optimized_solutions = 20
    #optimized_solutions = 2    # commented by SJ, to improve the optimization process
    finder.set_fast_mode(optimized_solutions)
    # read the images in blocks to avoid memory problems
    all_registration_results = []
    init_set = 0
    init_time = time.time()
    while(init_set < n_images):
        end_set = min(init_set + images_per_batch, n_images)
        if(images_dir != ""):
            os.chdir(images_dir)
        subjects = em2d.read_images(images_names[init_set:end_set], srw)
        # register
        finder.set_subjects(subjects)
        os.chdir(cwd)
        finder.get_complete_registration()
        # Recover the registration results:
        registration_results = finder.get_registration_results()
        for reg in registration_results:
            all_registration_results.append(reg)
        init_set += images_per_batch
    os.chdir(cwd)
    em2d.write_registration_results(
        "registration.params",
        all_registration_results)
    print("score_model: time complete registration %.2f"
          % (time.time() - init_time))
    print("coarse registration time %2.f"
          % finder.get_coarse_registration_time())
    print("fine registration time %.2f" % finder.get_fine_registration_time())
    return all_registration_results


def parse_args():
    p = IMP.ArgumentParser(description="EMageFit scoring")
    p.add_argument("model", help="PDB file of the model to score")
    p.add_argument("selfile", help="Selection file containing the images "
                                   "used for scoring")
    p.add_argument("pxsize", type=float,
                   help="Pixel size of the images in Angstrom/pixel")
    p.add_argument("nproj", type=int,
                   help="Number of projections of the model used for start "
                        "registering")
    p.add_argument("res", type=float,
                   help="Resolution to generate the projections")
    p.add_argument("nimg", type=int,
                   help="Images per batch used when scoring a lot of images "
                        "(to avoid memory problems)")
    return p.parse_args()


if __name__ == "__main__":
    args = parse_args()
    complete_fn_model = args.model
    images_sel_file = args.selfile
    pixel_size = args.pxsize
    n_projections = args.nproj
    resolution = args.res
    images_per_batch = args.nimg

    all_regs = score_model(complete_fn_model,
                           images_sel_file,
                           pixel_size,
                           n_projections,
                           resolution,
                           images_per_batch)
    for i, reg in enumerate(all_regs):
        print("Score for image %s: %f" % (i, reg.get_score()))
    print("Global score " + str(em2d.get_global_score(all_regs)))
