{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5ac56056-ca33-4f13-8e36-564b94144c1e",
   "metadata": {
    "tags": []
   },
   "source": [
    "<h1 align=\"center\">Markov Chain Monte Carlo for fun and profit</h1>\n",
    "<h1 align=\"center\"> 🎲 ⛓️ 👉 🧪 </h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eb5d773e-4cc0-48ae-bb71-7ece7ab5f936",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from numba import jit\n",
    "\n",
    "# This loads some custom styles for matplotlib\n",
    "import json, matplotlib\n",
    "\n",
    "with open(\"../_static/matplotlibrc.json\") as f:\n",
    "    matplotlib.rcParams.update(json.load(f))\n",
    "\n",
    "np.random.seed(\n",
    "    42\n",
    ")  # This makes our random numbers reproducible when the notebook is rerun in order"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "486f066c-f027-44e8-8937-8636a52f32fb",
   "metadata": {},
   "source": [
    "# Adding Functionality\n",
    "\n",
    "The main thing we want to be able to do is to take measurements, the code as I have writing it doesn't really allow that because it only returns the final state in the chain. Let's say we have a measurement called `average_color(state)` that we want to average over the whole chain. We could just stick that inside our definition of `mcmc`, but we know that we will likely make other measurements too, and we don't want to keep writing new versions of our core functionality!\n",
    "\n",
    "## Exercise 1\n",
    "Have a think about how you would implement this and what options you have."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c28b0a86-28f8-426f-9013-70e962f02256",
   "metadata": {},
   "source": [
    "## Solution 1\n",
    "So I chatted with my mentors on this project on how to best do this, and we came up with a few ideas:\n",
    "\n",
    "### Option 1: Just save all the states and return them\n",
    "\n",
    "The problem with this is the states are very big, and we don't want to waste all that memory. For an `NxN` state that uses 8-bit integers (the smallest we can use in numpy) `1000` samples would already use `2.5Gb` of memory! We will see later that we'd really like to be able to go a bit bigger than `50x50` and `1000` samples!\n",
    "\n",
    "### Option 2: Pass in a function to make measurements\n",
    "```python\n",
    "\n",
    "def mcmc(initial_state, steps, T, measurement, energy=energy):\n",
    "    ...\n",
    "\n",
    "    current_state = initial_state.copy()\n",
    "    E = energy(current_state)\n",
    "    for i in range(steps):\n",
    "        measurements[i] = measurement(state)\n",
    "        ...\n",
    "\n",
    "    return measurements\n",
    "```\n",
    "\n",
    "This could work, but it limits how we can store measurements and what shape and type they can be. What if we want to store our measurements in a numpy array? Or what if your measurement itself is a vector or and object that can't easily be stored in a numpy array? We would have to think carefully about what functionality we want."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7c9575f-2450-4298-a507-90f0c1b9b284",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Option 3: Use Inheritance\n",
    "```python\n",
    "# This class would define the basic functionality of performing MCMC\n",
    "class MCMCSampler(object):\n",
    "    def run(self, initial_state, steps, T):\n",
    "        ...\n",
    "        for i in range(steps):\n",
    "            self.measurement(state)\n",
    "\n",
    "       \n",
    "# This class would inherit from it and just implement the measurement\n",
    "class AverageColorSampler(MCMCSampler):\n",
    "    measurements = np.zeros(10)\n",
    "    index = 0\n",
    "    \n",
    "    def measurement(self, state):\n",
    "        self.measurements[self.index] = some_function(state)\n",
    "        self.index += 1\n",
    "        \n",
    "color_sampler = AverageColorSampler(...)\n",
    "measurements = color_sampler.run(...)\n",
    "```\n",
    "\n",
    "This would definitely work, but I personally am not a huge fan of object-oriented programming, so I'm going to skip this option!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d05d25d-c9ba-406d-9977-0ca4aeb430a7",
   "metadata": {},
   "source": [
    "## Option 4: Use a generator\n",
    "This is the approach I ended up settling on, we will use [python generator function](https://peps.python.org/pep-0255/). While you may not have come across generator functions before, you almost certainly will have come across generators, `range(n)` is a generator, `(i for i in [1,2,3])` is a generator. Generator functions are a way to build your own generators, by way of example here is range implemented as a generator function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b17c054-230f-4188-98a6-51fd8fe5b437",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<generator object my_range at 0x7fcaff25da50>, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def my_range(n):\n",
    "    \"Behaves like the builtin range function of one argument\"\n",
    "    i = 0\n",
    "    while i < n:\n",
    "        yield i  # sends i out to whatever function called us\n",
    "        i += 1\n",
    "    return  # let's python know that we have nothing else to give\n",
    "\n",
    "\n",
    "my_range(10), list(my_range(10))"
   ]
  },
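  {
   "cell_type": "markdown",
   "id": "generator-measurement-sketch",
   "metadata": {},
   "source": [
    "To see why generators help with measurements, here is a minimal sketch using a toy generator. `noisy_walk` and `running_mean` are hypothetical names made up for this illustration, and the random walk is only a stand-in for a real chain. The point is that the caller, not the generator, decides what to measure and how to store it:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def noisy_walk(steps, seed=0):\n",
    "    # Toy stand-in for an MCMC chain: yields one small state per step\n",
    "    rng = np.random.default_rng(seed)\n",
    "    state = np.zeros(4)\n",
    "    for _ in range(steps):\n",
    "        state = state + rng.normal(size=4)  # builds a new array each step\n",
    "        yield state\n",
    "\n",
    "\n",
    "def running_mean(values):\n",
    "    # Average over an iterable without storing all of its items\n",
    "    total, count = 0.0, 0\n",
    "    for v in values:\n",
    "        total += v\n",
    "        count += 1\n",
    "    return total / count\n",
    "\n",
    "\n",
    "# The caller picks the measurement; states are consumed one at a time\n",
    "mean_of_max = running_mean(np.max(s) for s in noisy_walk(1000))\n",
    "# A different measurement needs no change to the generator\n",
    "all_means = np.array([np.mean(s) for s in noisy_walk(1000)])\n",
    "```\n",
    "\n",
    "Neither measurement required touching `noisy_walk` itself, and the first one never holds more than one state in memory."
   ]
  },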
  {
   "cell_type": "markdown",
   "id": "b74fadbe-80c2-4a20-b651-0e47188b005a",
   "metadata": {},
   "source": [
    "This requires only a very small change to our `mcmc` function, and suddenly we can do whatever we like with the states! While we're at it, I'm going to add an argument `stepsize` that allows us to only sample the state every `stepsize` MCMC steps. You'll see why we would want to set this to value greater than 1 in a moment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f73d6335-6514-45b1-9128-d72122d8b0b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18.6 ms ± 645 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "13.3 ms ± 1.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "1x slowdown!\n"
     ]
    }
   ],
   "source": [
    "from MCFF.ising_model import energy, all_up_state\n",
    "from MCFF.mcmc import mcmc_original\n",
    "\n",
    "\n",
    "@jit(nopython=True, nogil=True)\n",
    "def mcmc_generator(initial_state, steps, T, stepsize=1000, energy=energy):\n",
    "    N, M = initial_state.shape\n",
    "    assert N == M\n",
    "\n",
    "    current_state = initial_state.copy()\n",
    "    E = energy(current_state)\n",
    "    for _ in range(steps):\n",
    "        for _ in range(stepsize):\n",
    "            i, j = np.random.randint(N), np.random.randint(N)\n",
    "\n",
    "            # modify the state a little, here we just flip a random pixel\n",
    "            current_state[i, j] *= -1\n",
    "            new_E = energy(current_state)\n",
    "\n",
    "            if (new_E < E) or np.exp(-(new_E - E) / T) > np.random.random():\n",
    "                E = new_E\n",
    "            else:\n",
    "                current_state[i, j] *= -1  # reject the change we made\n",
    "        yield current_state.copy()\n",
    "    return\n",
    "\n",
    "\n",
    "N_steps = 1000\n",
    "stepsize = 1\n",
    "initial_state = all_up_state(20)\n",
    "without_yield = %timeit -o mcmc_original(initial_state, steps = N_steps, T = 5)\n",
    "with_yield = %timeit -o [np.mean(s) for s in mcmc_generator(initial_state, T = 5, steps = N_steps, stepsize = 1)]\n",
    "print(f\"{with_yield.best / without_yield.best:.0f}x slowdown!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "132bf8d1-d341-494a-9881-689605a104b7",
   "metadata": {},
   "source": [
    "Fun fact: if you replace `yield current_state.copy()` with `yield current_state` your python kernel will crash when you run the code. I believe this is a bug in Numba that related to how pointers to numpy arrays work but let's not worry too much about it. \n",
    "\n",
    "We take a factor of two slowdown, but that doesn't seem so much to pay for the fact we can now sample the state at every single step rather than just the last."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b4e3c207-bc11-483a-aaf6-2daebb3194d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Author: Thomas Hodson\n",
      "\n",
      "Github username: T_Hodson\n",
      "\n",
      "Last updated: Mon Jul 18 2022\n",
      "\n",
      "Python implementation: CPython\n",
      "Python version       : 3.9.12\n",
      "IPython version      : 8.4.0\n",
      "\n",
      "Git hash: 03657e08835fdf23a808f59baa6c6a9ad684ee55\n",
      "\n",
      "Git repo: https://github.com/ImperialCollegeLondon/ReCoDE_MCMCFF.git\n",
      "\n",
      "Git branch: main\n",
      "\n",
      "json      : 2.0.9\n",
      "numpy     : 1.21.5\n",
      "matplotlib: 3.5.1\n",
      "\n",
      "Watermark: 2.3.1\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -n -u -v -iv -w -g -r -b -a \"Thomas Hodson\" -gu \"T_Hodson\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:recode]",
   "language": "python",
   "name": "conda-env-recode-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}