import pickle

import IMP
import IMP.test
import IMP.algebra
import IMP.core
import IMP.container
import IMP.example
try:
    import jax
except ImportError:
    jax = None


def make_modifier():
    """Create a one-particle model plus an ExampleSingletonModifier.

    The modifier is built on the box spanning (0, 0, 0)-(10, 10, 10). The
    single particle ("p1") is given XYZ coordinates (-4, 13, 28), which
    lie outside that box.

    Returns:
        (model, particle index from Model.add_particle, XYZ decorator,
        modifier)
    """
    model = IMP.Model()
    lower = IMP.algebra.Vector3D(0, 0, 0)
    upper = IMP.algebra.Vector3D(10, 10, 10)
    box = IMP.algebra.BoundingBox3D(lower, upper)
    pi = model.add_particle("p1")
    start = IMP.algebra.Vector3D(-4, 13, 28)
    xyz = IMP.core.XYZ.setup_particle(model, pi, start)
    modifier = IMP.example.ExampleSingletonModifier(box)
    return model, pi, xyz, modifier


class Tests(IMP.test.TestCase):
    """Tests for IMP.example's SingletonModifier (C++, Python and JAX)."""

    def _assert_xyz_near(self, d, expected, tol=1e-4):
        """Assert that XYZ decorator `d` lies within `tol` of `expected`.

        `expected` is an (x, y, z) sequence passed to IMP.algebra.Vector3D.
        """
        target = IMP.algebra.Vector3D(*expected)
        self.assertLess(
            IMP.algebra.get_distance(d.get_coordinates(), target), tol)

    def _assert_jax_xyz_near(self, X, index, expected, tol=1e-3):
        """Assert that row `index` of the JAX model's X['xyz'] is near
        the point `expected` (Euclidean norm of the difference < `tol`)."""
        import jax.numpy as jnp
        err = jnp.linalg.norm(X['xyz'][index] - jnp.array(expected))
        self.assertLess(err, tol)

    def test_modifier(self):
        """Test example SingletonModifier"""
        for typ in (IMP.example.ExampleSingletonModifier,
                    IMP.example.PythonExampleSingletonModifier):
            # subTest reports which implementation failed and still
            # checks the remaining one.
            with self.subTest(typ=typ):
                m = IMP.Model()
                bb = IMP.algebra.BoundingBox3D(
                    IMP.algebra.Vector3D(1, 2, 3),
                    IMP.algebra.Vector3D(10, 10, 10))
                p = m.add_particle("p1")
                d = IMP.core.XYZ.setup_particle(
                    m, p, IMP.algebra.Vector3D(-4, 13, 28))
                s = typ(bb)
                s.apply_index(m, p)
                # Expected point is consistent with periodic wrapping of
                # (-4, 13, 28) into bb, whose lower corner is off-origin.
                self._assert_xyz_near(d, (5, 5, 7))
                self.assertIn("SingletonModifier", str(s))
                self.assertIn("SingletonModifier", repr(s))
                self.assertIn("example", s.get_version_info().get_module())
                self.assertEqual(len(s.get_inputs(m, [p])), 1)
                self.assertEqual(len(s.get_outputs(m, [p])), 1)

    def test_combine(self):
        """Test combining example SingletonModifier with IMP classes"""
        for typ in (IMP.example.ExampleSingletonModifier,
                    IMP.example.PythonExampleSingletonModifier):
            with self.subTest(typ=typ):
                m = IMP.Model()
                bb = IMP.algebra.BoundingBox3D(
                    IMP.algebra.Vector3D(1, 2, 3),
                    IMP.algebra.Vector3D(10, 10, 10))
                p = m.add_particle("p")
                d = IMP.core.XYZ.setup_particle(
                    m, p, IMP.algebra.Vector3D(-4, 13, 28))
                c = IMP.core.SingletonConstraint(typ(bb), None, m, p)
                m.add_score_state(c)
                # The modifier is applied only via the constraint when the
                # model is updated.
                m.update()
                self._assert_xyz_near(d, (5, 5, 7))

    def test_pickle(self):
        """Test (un-)pickle of ExampleSingletonModifier"""
        m, p, d, s = make_modifier()
        news = pickle.loads(pickle.dumps(s))
        news.apply_index(m, p)
        self._assert_xyz_near(d, (6, 3, 8))

    def test_pickle_polymorphic(self):
        """Test (un-)pickle of ExampleSingletonModifier via polymorphic ptr"""
        m, p, d, s = make_modifier()
        c = IMP.core.SingletonConstraint(s, None, m, p)
        newc = pickle.loads(pickle.dumps(c))
        newc.before_evaluate()
        self._assert_xyz_near(d, (6, 3, 8))

    @IMP.test.skipIf(jax is None, "No JAX support")
    def test_jax(self):
        """Test JAX implementation of SingletonModifier"""
        import IMP._jax_util
        m, p, _, s = make_modifier()
        ji = s._get_jax(m, p)
        X = IMP._jax_util._get_jax_model(m, ji._keys)
        f = jax.jit(ji.apply_func)
        X = f(X, p)
        # p is the first particle in the model, so presumably row 0.
        self._assert_jax_xyz_near(X, 0, (6., 3., 8.))

    @IMP.test.skipIf(jax is None, "No JAX support")
    def test_jax_singleton_constraint(self):
        """Test JAX SingletonModifier in a SingletonConstraint"""
        import IMP._jax_util
        m, p, _, s = make_modifier()
        c = IMP.core.SingletonConstraint(s, None, m, p)

        ji = c._get_jax()
        X = ji.get_jax_model()
        f = jax.jit(ji.apply_func)
        X = f(X)
        self._assert_jax_xyz_near(X, 0, (6., 3., 8.))

    @IMP.test.skipIf(jax is None, "No JAX support")
    def test_jax_singletons_constraint(self):
        """Test JAX SingletonModifier in a SingletonsConstraint"""
        import IMP._jax_util
        # make_modifier() supplies p1 at (-4, 13, 28); add a second
        # out-of-box particle so the container holds two.
        m, p1, _, s = make_modifier()
        p2 = m.add_particle("p2")
        IMP.core.XYZ.setup_particle(m, p2, IMP.algebra.Vector3D(3, 20, 42))
        lsc = IMP.container.ListSingletonContainer(m, [p1, p2])
        c = IMP.container.SingletonsConstraint(s, None, lsc)

        ji = c._get_jax()
        X = ji.get_jax_model()
        f = jax.jit(ji.apply_func)
        X = f(X)
        self._assert_jax_xyz_near(X, 0, (6., 3., 8.))
        self._assert_jax_xyz_near(X, 1, (3., 0., 2.))


if __name__ == '__main__':
    # Run the tests in this file through IMP's test harness.
    IMP.test.main()
