import jax
import jax.numpy as jnp
import IMP.atom
import jax.tree_util
from dataclasses import dataclass
import IMP._jax_util


# Conversion from derivatives (in kcal/mol/A) to acceleration (A/fs/fs).
# Callers multiply an energy derivative by this factor and divide by the
# particle mass (g/mol). 1 (kcal/mol/A)/(g/mol) == 4.1868e-4 A/fs^2 when
# 1 kcal = 4186.8 J. The negative sign converts the energy gradient into a
# force, since the force points down the gradient.
_deriv_to_acceleration = -4.1868e-4


def _propagate_coordinates(jm, indexes, mass, time_step, velocity_cap=None):
    """First half of a velocity Verlet step; updates ``jm`` in place.

    For the particles selected by ``indexes``:
    1. Advance ``jm['linvel']`` by half a time step using the current
       derivatives stored in ``jm["xyz'"]``.
    2. Optionally clamp each velocity component to
       [-velocity_cap, velocity_cap].
    3. Move ``jm['xyz']`` by a full time step using the new velocities.

    ``mass`` must broadcast against the per-particle coordinate arrays.
    The caller in this module passes a column of shape (n, 1).
    ``velocity_cap`` is None (no capping) or a value accepted by
    ``jnp.clip``.
    """
    old_velocity = jm['linvel'].at[indexes].get()
    gradient = jm["xyz'"][indexes]
    # Same left-to-right operation order as the reference implementation,
    # so results are bit-for-bit identical.
    half_kick = time_step * 0.5 * gradient * _deriv_to_acceleration / mass
    new_velocity = old_velocity + half_kick
    if velocity_cap is not None:
        new_velocity = jnp.clip(new_velocity, -velocity_cap, velocity_cap)
    jm['xyz'] = jm['xyz'].at[indexes].add(new_velocity * time_step)
    jm['linvel'] = jm['linvel'].at[indexes].set(new_velocity)


def _propagate_velocities(jm, indexes, mass, time_step):
    """Second half of a velocity Verlet step; updates ``jm['linvel']`` in place.

    Adds half a time step's worth of acceleration, computed from the
    derivatives in ``jm["xyz'"]``, to the velocities of the particles
    selected by ``indexes``. No velocity capping is applied here.
    ``mass`` must broadcast against the per-particle arrays.
    """
    half_kick = (time_step * 0.5 * jm["xyz'"][indexes]
                 * _deriv_to_acceleration / mass)
    jm['linvel'] = jm['linvel'].at[indexes].add(half_kick)


@jax.tree_util.register_dataclass
@dataclass
class _MolecularDynamics:
    """Track the state of a MolecularDynamics optimization using JAX"""

    # Current JAX Model
    jm: dict
    # Number of steps taken
    steps: int
    # JAX random number key
    rkey: jax.Array
    # Any persistent state used by OptimizerStates. Each OptimizerState's
    # _get_jax() method is given a unique index into this list.
    optimizer_states: list
    # Indexes of all particles subject to MD
    simulation_indexes: jax.Array
    # Number of degrees of freedom in the system
    degrees_of_freedom: int
    # Time between integrator steps
    time_step: float

    def get_kinetic_energy(self):
        """Return the current kinetic energy of the system, in kcal/mol"""
        # Conversion factor to get energy in kcal/mol from velocities
        # in A/fs and mass in g/mol
        conversion = 1.0 / 4.1868e-4

        indexes = self.simulation_indexes
        velocity = self.jm['linvel'][indexes]
        mass = self.jm['mass'][indexes]
        return 0.5 * conversion * jnp.sum(
            mass * jnp.sum(jnp.square(velocity), axis=1))

    def get_kinetic_temperature(self, ekinetic):
        """Return the current kinetic temperature of the system"""
        # E = (n/2)kT  n=degrees of freedom, k = Boltzmann constant
        # Boltzmann constant, in kcal/mol
        boltzmann = 8.31441 / 4186.8
        return 2.0 * ekinetic / (self.degrees_of_freedom * boltzmann)


class _MDJAXInfo(IMP._jax_util.JAXOptimizerInfo):
    """JAX implementation hooks for a MolecularDynamics optimizer.

    Provides ``init_func`` (build a ``_MolecularDynamics`` state from a JAX
    Model and a random key) and ``apply_func`` (advance that state by one
    velocity Verlet step, then run any periodic JAX OptimizerStates). It also
    extends the base JAX Model with particle masses and linear velocities.
    """
    def __init__(self, md):
        """Build the JAX integrator functions for optimizer ``md``.

        ``md`` must provide get_velocity_cap(),
        get_simulation_particle_indexes(), get_degrees_of_freedom() and
        get_maximum_time_step(). It is presumably an
        IMP.atom.MolecularDynamics; confirm against callers.
        """
        super().__init__(md)
        # score_func returns both score and a modified JAX Model, but
        # deriv_func only wants the first scalar argument (the score)
        deriv_func = jax.grad(lambda jm: self.score_func(jm)[0])
        velocity_cap = md.get_velocity_cap()
        # Would like to use math.isfinite here but it is not guaranteed
        # that a C++ "infinite" value is also considered to be math.inf
        # (any cap >= 1e20 is therefore treated as "no cap").
        if velocity_cap < 1e20:
            # One cap per vector component, for jnp.clip in
            # _propagate_coordinates
            velocity_cap = jnp.array([velocity_cap] * 3)
        else:
            velocity_cap = None
        jax_optstates = self._setup_jax_optimizer_states()

        def init_func(jm, key):
            # Compute derivatives at t=0 so the first half-kick in
            # apply_func has forces to work with
            jm["xyz'"] = deriv_func(jm)["xyz"]
            s = _MolecularDynamics(
                jm=jm, steps=0, optimizer_states=[None] * len(jax_optstates),
                simulation_indexes=md.get_simulation_particle_indexes(),
                degrees_of_freedom=md.get_degrees_of_freedom(),
                rkey=key, time_step=md.get_maximum_time_step())
            # Let each JAX OptimizerState initialize its own persistent state
            for js in jax_optstates:
                s = js.init_func(s)
            return s

        def apply_func(ms):
            # One velocity Verlet step. jm is the same dict object as
            # ms.jm, so the in-place updates below are visible through ms.
            jm = ms.jm
            indexes = ms.simulation_indexes
            ms.steps += 1
            mass = jm['mass'][indexes]
            # Make mass 2D so propagate functions can broadcast it over
            # the 2D coordinate/velocity arrays
            mass = mass.reshape(mass.shape[0], 1)
            # Get coordinates at t+(delta t) and velocities at t+(delta t/2)
            _propagate_coordinates(jm, indexes, mass, ms.time_step,
                                   velocity_cap)
            # Get new derivatives at t+(delta t)
            jm["xyz'"] = deriv_func(jm)["xyz"]
            # Get velocities at t+(delta t)
            _propagate_velocities(jm, indexes, mass, ms.time_step)
            # Run each OptimizerState only when the (already incremented)
            # step count is a multiple of its period. lax.cond is used
            # because the step count is traced under jit.
            steps = ms.steps
            for js in jax_optstates:
                ms = jax.lax.cond(steps % js.period == 0, js.apply_func,
                                  lambda x: x, ms)
            return ms

        self.init_func = init_func
        self.apply_func = apply_func

        # Force MolecularDynamics to create linvel for all particles
        _ = md.get_simulation_particle_indexes()

    def get_jax_model(self):
        """Return the base JAX Model extended with 'mass' and 'linvel'.

        'linvel' and 'xyz' are copied into JAX arrays, because the
        integrator updates them with ``.at[...]``. 'mass' is left as the
        NumPy array returned by the Model, since the integrator only reads it.
        """
        jm = super().get_jax_model()
        # NOTE(review): self._opt is presumably set by the base class to
        # the optimizer passed to __init__; confirm.
        m = self._opt.get_model()
        jm['mass'] = m.get_floats_numpy(IMP.atom.Mass.get_mass_key())
        jm['linvel'] = jax.numpy.array(
            m.get_vector3ds_numpy(IMP.atom.LinearVelocity.get_velocity_key()))
        jm['xyz'] = jax.numpy.array(jm['xyz'])
        return jm


def _md_optimize(md, max_steps):
    """Run molecular dynamics ``md`` with JAX for up to ``max_steps`` steps.

    The integrator step is JIT-compiled and run in chunks of
    ``inner_steps`` iterations, as chosen by ``_JAXOptimizer``. After each
    chunk, the IMP Model's velocity, coordinate and derivative arrays are
    overwritten in place with the current JAX state. Returns the score
    evaluated on the final JAX Model.
    """
    from IMP.core._jax_util import _JAXOptimizer

    driver = _JAXOptimizer(md, max_steps)
    chunk_size = driver.inner_steps
    info = md._get_jax()

    def _run_chunk(state):
        # Apply chunk_size integrator steps inside a single compiled loop
        return jax.lax.fori_loop(0, chunk_size,
                                 lambda _i, st: info.apply_func(st), state)

    jit_init = jax.jit(info.init_func)
    jit_score = jax.jit(info.score_func)
    jit_chunk = jax.jit(_run_chunk)

    state = jit_init(info.get_jax_model(),
                     key=IMP._jax_util.get_random_key())
    model = md.get_model()
    # NumPy arrays through which the JAX state is copied back into the
    # Model after each chunk
    model_linvel = model.get_vector3ds_numpy(
        IMP.atom.LinearVelocity.get_velocity_key())
    model_xyz = model.get_spheres_numpy()[0]
    model_dxyz = model.get_sphere_derivatives_numpy()[0]

    for _ in driver.loop():
        state = jit_chunk(state)
        current = state.jm
        model_linvel[:] = current['linvel']
        model_xyz[:] = current['xyz']
        model_dxyz[:] = current["xyz'"]
    score, state.jm = jit_score(state.jm)
    return score
