{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# JIT and parallelisation\n",
    "\n",
    "This notebook gives a brief overview of how alpro uses just-in-time (JIT) compilation and how to run alpro calculations in parallel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt \n",
    "import numpy as np \n",
    "import alpro \n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Just In Time\n",
    "alpro uses [just-in-time (JIT) compilation from the numba library](https://numba.readthedocs.io/en/stable/user/jit.html). In practice, this means that the python functions that do the matrix manipulations are wrapped in `@jit` decorators. \n",
    "\n",
    "Here's an example of how this works with a for loop. We make two identical functions, give one of them a `@jit` decorator, and call both versions. The first jit call is actually slower than the pure python (because the compilation is deferred until the first call), but for subsequent calls the numba optimised function is way quicker. Just like magic. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time, pure python: 0.004478931427001953\n",
      "time, jit (call 1): 0.09760212898254395\n",
      "time, jit (call 2): 5.626678466796875e-05\n"
     ]
    }
   ],
   "source": [
    "from numba import jit \n",
    "\n",
    "@jit(nopython=True)\n",
    "def sum_with_jit(N):\n",
    "    my_sum = 0\n",
    "    for i in range(N):\n",
    "        my_sum += 1\n",
    "    return my_sum\n",
    "\n",
    "def sum_without(N):\n",
    "    my_sum = 0\n",
    "    for i in range(N):\n",
    "        my_sum += 1\n",
    "    return my_sum\n",
    "\n",
    "t1 = time.time()\n",
    "x = sum_without(100000)\n",
    "print (\"time, pure python:\", time.time() - t1)        \n",
    "\n",
    "t1 = time.time()\n",
    "x = sum_with_jit(100000)\n",
    "print (\"time, jit (call 1):\", time.time() - t1)  \n",
    "\n",
    "t1 = time.time()\n",
    "x = sum_with_jit(100000)\n",
    "print (\"time, jit (call 2):\", time.time() - t1)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Exactly the same thing happens with alpro. The first calculation is quite slow because it includes the compilation, but subsequent calculations are nice and fast.\n",
    "\n",
    "So, be aware of this when you are writing code to use alpro! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time: 2.9819839000701904\n",
      "time: 0.11284494400024414\n"
     ]
    }
   ],
   "source": [
    "import alpro\n",
    "s = alpro.Survival(\"1275b\")\n",
    "s.init_model()\n",
    "s.set_params(1e-12 * 1e-9, 1e-13)\n",
    "s.domain.create_box_array(1800, 0, s.coherence_func, r0=0)\n",
    "energies = np.logspace(3,4,1000)\n",
    "\n",
    "# the first pass through the loop includes the JIT compilation time\n",
    "for n in range(2):\n",
    "    t1 = time.time()\n",
    "    P, Pradial = s.propagate(s.domain, energies, pol=\"both\")\n",
    "    print (\"time:\", time.time() - t1)"
   ]
  },
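  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When timing alpro calls yourself, it can help to record each call separately, so that the compilation cost of the first call is not mixed in with the rest. Here is a minimal sketch of one way to do that. Note that `time_calls` is a hypothetical helper written for this notebook (it is not part of alpro or numba) and it only uses the standard library.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "def time_calls(func, *args, n_calls=3, **kwargs):\n",
    "    # Hypothetical helper (not part of alpro or numba): call func n_calls\n",
    "    # times and return the wall-clock duration of each call in seconds.\n",
    "    durations = []\n",
    "    for _ in range(n_calls):\n",
    "        start = time.perf_counter()\n",
    "        func(*args, **kwargs)\n",
    "        durations.append(time.perf_counter() - start)\n",
    "    return durations\n",
    "\n",
    "# For a jitted function, the first entry would include the compilation time.\n",
    "first, *rest = time_calls(sum, range(100000))\n",
    "print('first call:', first, 'later calls:', rest)\n",
    "```\n",
    "\n",
    "With the objects defined above, something like `time_calls(s.propagate, s.domain, energies, pol='both')` should show the same slow-then-fast pattern."
   ]
  },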
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parallelisation\n",
    "If ``mpi4py`` is installed alongside alpro then you will have access to a few routines for running calculations in parallel. The function in question is ``parallel.run``, which takes the following arguments:\n",
    "\n",
    "        function    callable\n",
    "                    function to compute\n",
    "\n",
    "        iterable    iterable\n",
    "                    iterable containing first arguments\n",
    "\n",
    "        kwargs      dictionary\n",
    "                    dictionary of keyword arguments to pass to function\n",
    "\n",
    "        split       str\n",
    "                    how to split up the parallel processes\n",
    "                    \n",
    "The function must take as its first argument a variable that is consistent with the type of iterable. For example, if your function takes a string as an argument, then an appropriate iterable would be a list of strings. In alpro, I normally use it for iterating over the RNG seed. An example is shown below, which also shows you how to use the kwargs parameter. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "This is thread 0 calculating models 0 to 9, total 10\n",
      "Waiting for other threads to finish...\n",
      "Thread 0 took 1.1756579875946045 seconds to calculate 10 models\n"
     ]
    }
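   ],
   "source": [
    "def my_calculation(i, g = 1e-12 * 1e-9):\n",
    "    # i is the element of the iterable handled by this call. It is unused here,\n",
    "    # but could be used to set, e.g., the RNG seed\n",
    "    s = alpro.Survival(\"1275b\")\n",
    "    s.init_model()\n",
    "    s.set_params(g, 1e-13)\n",
    "    s.domain.create_box_array(1800, 0, s.coherence_func, r0=0)\n",
    "    # energies is the global array defined in an earlier cell\n",
    "    s.propagate(s.domain, energies, pol=\"both\")\n",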
    "\n",
    "import alpro.parallel as parallel \n",
    "iterable = range(0,10)\n",
    "kwargs = {\"g\": 1e-1 * 1e-9}\n",
    "time_taken = parallel.run(my_calculation, iterable, kwargs = kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the notebook this executes in serial on a single process, but to run in parallel you could save the above code (together with the imports and the definition of ``energies``) as ``script.py`` and run with, e.g.\n",
    "\n",
    "    mpirun -n 4 python script.py"
   ]
  }
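  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give a feel for what sharing the iterable between processes can look like, here is a minimal sketch that hands each process a contiguous block of models. This is an assumption for illustration only, not alpro's actual implementation: `contiguous_chunk` is a hypothetical helper, and the real splitting is controlled by the ``split`` argument of ``parallel.run``.\n",
    "\n",
    "```python\n",
    "def contiguous_chunk(n_items, n_ranks, rank):\n",
    "    # Hypothetical helper (not part of alpro): return the (start, stop)\n",
    "    # indices of the contiguous block of work assigned to this rank.\n",
    "    base, extra = divmod(n_items, n_ranks)\n",
    "    # The first `extra` ranks each take one additional item.\n",
    "    start = rank * base + min(rank, extra)\n",
    "    stop = start + base + (1 if rank < extra else 0)\n",
    "    return start, stop\n",
    "\n",
    "# Ten models shared between four processes.\n",
    "for rank in range(4):\n",
    "    start, stop = contiguous_chunk(10, 4, rank)\n",
    "    print('rank', rank, 'computes models', list(range(start, stop)))\n",
    "```\n",
    "\n",
    "Under ``mpirun``, the rank and the number of processes could be obtained from mpi4py via ``MPI.COMM_WORLD.Get_rank()`` and ``MPI.COMM_WORLD.Get_size()``."
   ]
  }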
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
