mirror of
https://github.com/ImperialCollegeLondon/ReCoDE_MCMCFF.git
synced 2025-06-26 08:51:16 +02:00
{
"cells": [
{
"cell_type": "markdown",
"id": "5ac56056-ca33-4f13-8e36-564b94144c1e",
"metadata": {
"tags": []
},
"source": [
"<h1 align=\"center\">Markov Chain Monte Carlo for fun and profit</h1>\n",
"<h1 align=\"center\"> 🎲 ⛓️ 👉 🧪 </h1>"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eb5d773e-4cc0-48ae-bb71-7ece7ab5f936",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from numba import jit\n",
"\n",
"# This loads some custom styles for matplotlib\n",
"import json, matplotlib\n",
"\n",
"with open(\"assets/matplotlibrc.json\") as f:\n",
"    matplotlib.rcParams.update(json.load(f))\n",
"\n",
"np.random.seed(\n",
"    42\n",
")  # This makes our random numbers reproducible when the notebook is rerun in order\n",
"\n",
"\n",
"def show_state(state, ax=None):\n",
"    if ax is None:\n",
"        f, ax = plt.subplots()\n",
"    ax.matshow(state, cmap=\"Greys\", vmin=-1, vmax=1)\n",
"    ax.set(xticks=[], yticks=[])"
]
},
{
"cell_type": "markdown",
"id": "337f1de8-d743-441f-bc15-387bcfff558d",
"metadata": {
"tags": []
},
"source": [
"# Doing Monte Carlo!\n",
"\n",
"Now that we can evaluate the energy of a state, there isn't much more work needed to do Markov Chain Monte Carlo on it. I won't go into the details of how MCMC works, but put very simply:\n",
"\n",
"We want to calculate thermal averages of a physical system. For example, is this bag of H2O molecules solid or liquid at T = -20 °C? Our Ising model is much simpler, so the equivalent question would be: what's the average color of this system at some temperature T?\n",
"\n",
"It turns out that this question is pretty hard to answer using maths. It can be done exactly for the 2D Ising model, but for anything more complicated it's pretty much impossible. This is where MCMC comes in: MCMC is a numerical method that gives us a rule to probabilistically jump from one state of the system to another.\n",
"\n",
"If we perform many such jumps, we get a (Markov) chain of states. The great thing about this chain is that if we average a measurement over it, such as the proportion of white pixels, the answer we get will be close to the true thermal average for this system, and it will converge closer and closer to that value as we extend the chain.\n",
"\n",
"I've written a very basic MCMC sampler for the 2D Ising model below. It needs:\n",
"- an initial state to start the chain\n",
"- to know how many steps to take\n",
"- the temperature we want to simulate at\n",
"- a way to measure the energy of a state, which we wrote in a previous chapter\n",
"\n",
"It then repeats the following steps:\n",
"- modify the state a little; here we just flip a single randomly chosen pixel\n",
"- accept the change with probability $p = \\min(1, \\exp(-\\Delta E / T))$, where $\\Delta E$ is the change in energy\n",
"- if we rejected it, put the state back to how it was"
]
},
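{
"cell_type": "markdown",
"id": "metropolis-acceptance-sketch",
"metadata": {},
"source": [
"The accept/reject step above is the Metropolis rule. The following is a minimal plain-Python sketch without numba. `metropolis_accept` is a hypothetical helper name and is not part of the sampler below. The sketch checks two things: an uphill move with $\\Delta E > 0$ is accepted a fraction $\\exp(-\\Delta E / T)$ of the time, and downhill moves are always accepted.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def metropolis_accept(delta_E, T, rng):\n",
"    # Downhill (or level) moves are always accepted...\n",
"    if delta_E <= 0:\n",
"        return True\n",
"    # ...uphill moves are accepted with probability exp(-delta_E / T)\n",
"    return rng.random() < np.exp(-delta_E / T)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"T, delta_E, trials = 5.0, 4.0, 100_000\n",
"accepted = sum(metropolis_accept(delta_E, T, rng) for _ in range(trials))\n",
"\n",
"print(metropolis_accept(-1.0, T, rng))  # True: downhill moves always pass\n",
"print(accepted / trials, np.exp(-delta_E / T))  # these two should be close\n",
"```\n",
"\n",
"The sampler below writes the same rule as `(new_E < E) or np.exp(-(new_E - E) / T) > np.random.random()`. The two forms agree because `np.random.random()` is always below 1, so a move with zero energy change is accepted either way."
]
},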
{
"cell_type": "code",
"execution_count": 24,
"id": "2586a542-35f2-419e-9aa2-2bb9e9ab74b9",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 2200x550 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"@jit(nopython=True, nogil=True)\n",
"def energy(state):\n",
"    E = 0\n",
"    N, M = state.shape\n",
"    for i in range(N):\n",
"        for j in range(M):\n",
"            # bond with the south neighbour (each vertical bond is counted once)\n",
"            if 0 <= (i + 1) < N:\n",
"                E -= state[i, j] * state[i + 1, j]\n",
"\n",
"            # bond with the east neighbour (each horizontal bond is counted once)\n",
"            if 0 <= (j + 1) < M:\n",
"                E -= state[i, j] * state[i, j + 1]\n",
"\n",
"    return 2 * E / (N * M)\n",
"\n",
"\n",
"# While writing numba it's useful to keep the list of supported numpy functions open:\n",
"# https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html\n",
"@jit(nopython=True, nogil=True)\n",
"def mcmc(initial_state, steps, T, energy=energy):\n",
"    N, M = initial_state.shape\n",
"    assert N == M\n",
"\n",
"    current_state = initial_state.copy()\n",
"    E = N**2 * energy(current_state)\n",
"    for _ in range(steps):\n",
"        i, j = np.random.randint(N), np.random.randint(N)\n",
"\n",
"        # modify the state a little, here we just flip a random pixel\n",
"        current_state[i, j] *= -1\n",
"        new_E = N**2 * energy(current_state)\n",
"\n",
"        # Metropolis rule: always accept if the energy went down,\n",
"        # otherwise accept with probability exp(-(new_E - E) / T)\n",
"        if (new_E < E) or np.exp(-(new_E - E) / T) > np.random.random():\n",
"            E = new_E\n",
"        else:\n",
"            current_state[i, j] *= -1  # reject the change we made\n",
"\n",
"    return current_state\n",
"\n",
"\n",
"Ts = [4, 5, 50]\n",
"\n",
"ncols = 1 + len(Ts)\n",
"f, axes = plt.subplots(ncols=ncols, figsize=(5 * ncols, 5))\n",
"\n",
"initial_state = np.ones(shape=(50, 50))\n",
"axes[0].set(title=\"Initial state\")\n",
"show_state(initial_state, ax=axes[0])\n",
"\n",
"for T, ax in zip(Ts, axes[1:]):\n",
"    # initial_state = rng.choice([1,-1], size = (50,50))\n",
"\n",
"    final_state = mcmc(initial_state, steps=100_000, T=T)\n",
"    show_state(final_state, ax=ax)\n",
"    ax.set(title=f\"T = {T}\")"
]
},
{
"cell_type": "markdown",
"id": "5d1874d4-4585-49ed-bc6f-b11c22231669",
"metadata": {},
"source": [
"These images give a flavour of why physicists find this model useful: it gives a window into how thermal noise and spontaneous order interact. At low temperatures the energy cost of being different from your neighbours is the most important thing, while at high temperatures it doesn't matter and you really just do your own thing.\n",
"\n",
"There's a special point somewhere in the middle called the critical point $T_c$ where all sorts of cool things happen. My favourite is that for large system sizes you get a kind of fractal behaviour, which I will demonstrate once we've sped this code up and can simulate larger systems in a reasonable time. You can kind of see it for 50x50 systems at T = 5, but not very clearly."
]
},
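{
"cell_type": "markdown",
"id": "critical-temperature-sketch",
"metadata": {},
"source": [
"For the infinite 2D square lattice, Onsager's solution gives the critical point exactly. Take spins $s_i = \\pm 1$, energy $E = -J \\sum_{\\langle ij \\rangle} s_i s_j$ (each bond counted once) and $k_B = 1$. The critical point then sits where $\\sinh(2J / T_c) = 1$, i.e. $T_c = 2J / \\ln(1 + \\sqrt{2}) \\approx 2.269 J$. In `mcmc` above the energy used is `N**2 * energy(state)`, which works out to $-2 \\sum_{\\langle ij \\rangle} s_i s_j$, so $J = 2$ in this convention. Below is a small sketch of the arithmetic. The helper name `critical_temperature` is just for illustration, and a finite 50x50 lattice with open boundaries will only show the transition approximately.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def critical_temperature(J=1.0):\n",
"    # Onsager: sinh(2J / T_c) = 1  =>  T_c = 2J / ln(1 + sqrt(2)), with k_B = 1\n",
"    return 2 * J / np.log(1 + np.sqrt(2))\n",
"\n",
"\n",
"print(critical_temperature(1.0))  # about 2.269, the usual textbook value\n",
"print(critical_temperature(2.0))  # about 4.538, the convention used by mcmc above\n",
"```\n",
"\n",
"That places $T_c$ between the T = 4 and T = 5 panels above."
]
},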
{
"cell_type": "markdown",
"id": "5f728039-a975-4083-b68e-a13b4f2d1f87",
"metadata": {},
"source": [
"The code we have so far is really just a sketch of a solution, so this is a good time to step back and think about what our aims are and how this software will fulfil them. I see three broad areas in which it needs improvement:\n",
"\n",
"**Functionality**\n",
"Right now we can't really do much except plot nice pictures of states, but (within the fiction of this project) we really want to be able to do science! So we need to think about what measurements and observations we might want to make and how that might affect the structure of our code.\n",
"\n",
"**Testing**\n",
"I've already missed at least one devastating bug in this code, and there are almost certainly more! Before we start adding too much new code, we should think about how to increase our confidence that the individual components are working correctly. It's very easy to build a huge project out of hundreds of functions, realise there's a bug and then struggle to find its source. If we test our components individually and thoroughly, we can avoid some of that pain.\n",
"\n",
"**Performance**\n",
"Performance only matters insofar as it limits what we can do. There is a real danger that trying to optimise for performance too early or in the wrong places will just lead to complexity that makes the code harder to read, harder to write and more likely to contain bugs. However, I do want to show you the fractal states at the critical point, and I can't currently generate those images in a reasonable time, so some optimisation will happen!"
]
},
{
"cell_type": "markdown",
"id": "486f066c-f027-44e8-8937-8636a52f32fb",
"metadata": {},
"source": [
"## Functionality\n",
"\n",
"The main thing we want to be able to do is take measurements, but the code as I have written it doesn't really allow that, because it only returns the final state in the chain. Let's say we have a measurement called `average_color(state)` that we want to average over the whole chain. We could just stick that inside our definition of `mcmc`, but we know that we will likely make other measurements too, and we don't want to keep writing new versions of our core functionality!\n",
"\n",
"## Exercise 1\n",
"Have a think about how you would implement this and what options you have."
]
},
{
"cell_type": "markdown",
"id": "c28b0a86-28f8-426f-9013-70e962f02256",
"metadata": {},
"source": [
"## Solution 1\n",
"I chatted with my mentors on this project about how best to do this, and we came up with a few ideas:\n",
"\n",
"### Just save all the states and return them\n",
"\n",
"The problem with this is that the states are big and we don't want to waste all that memory. An NxN state stored as 8-bit integers (the smallest integer type in numpy) takes $N^2$ bytes, so a single 50x50 state is 2.5 kB. Saving every state of a 100,000-step chain like the ones above would already use 250 MB of memory. With the float64 arrays that `np.ones` creates by default, it would be 2 GB! We will see later that we'd really like to be able to go a bit bigger than 50x50 and run longer chains, which only makes this worse.\n",
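"\n",
"A quick way to sanity-check memory estimates like this is to ask numpy for the size of one state with `nbytes` and scale up. The minimal sketch below assumes 50x50 states and counts only the raw array data (no Python object overhead):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# One 50x50 state stored as 8-bit integers: 50 * 50 = 2500 bytes\n",
"bytes_per_state = np.ones((50, 50), dtype=np.int8).nbytes\n",
"\n",
"for n_states in (1_000, 100_000):\n",
"    print(n_states, 'states:', n_states * bytes_per_state / 1e6, 'MB')\n",
"\n",
"# np.ones defaults to float64, which is 8 bytes per pixel instead of 1\n",
"print(np.ones((50, 50)).nbytes)  # 20000\n",
"```\n",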
"\n",
"### Pass in a function to make measurements\n",
"```python\n",
"def mcmc(initial_state, steps, T, measurement, energy=energy):\n",
"    ...\n",
"\n",
"    current_state = initial_state.copy()\n",
"    E = N**2 * energy(current_state)\n",
"    measurements = np.zeros(steps)  # one float per step\n",
"    for i in range(steps):\n",
"        measurements[i] = measurement(current_state)\n",
"        ...\n",
"\n",
"    return measurements\n",
"```\n",
"\n",
"This could work, but it fixes how we store the measurements and what shape and type they can have: here they have to fit into a preallocated numpy array of floats. What if we want to store them some other way? Or what if a measurement is itself a vector, or an object that can't easily be stored in a numpy array? We would have to think carefully up front about what functionality we want."
]
},
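{
"cell_type": "markdown",
"id": "measurement-function-sketch",
"metadata": {},
"source": [
"To make the trade-off concrete, here is a minimal runnable sketch of the pass-in-a-function option in plain numpy, without numba. To make it run on its own, it includes a vectorised re-implementation of the same open-boundary `energy`. The name `mcmc_with_measurement` is hypothetical. `average_color` is assumed here to be simply the mean pixel value (+1 or -1 per pixel).\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def energy(state):\n",
"    # Same convention as the numba version: bonds counted once, open boundaries\n",
"    E = -np.sum(state[:-1, :] * state[1:, :]) - np.sum(state[:, :-1] * state[:, 1:])\n",
"    return 2 * E / state.size\n",
"\n",
"\n",
"def average_color(state):\n",
"    return state.mean()\n",
"\n",
"\n",
"def mcmc_with_measurement(initial_state, steps, T, measurement, rng):\n",
"    N, M = initial_state.shape\n",
"    current_state = initial_state.copy()\n",
"    E = N * M * energy(current_state)\n",
"    measurements = np.zeros(steps)  # only works if measurement returns a single float\n",
"    for step in range(steps):\n",
"        i, j = rng.integers(N), rng.integers(M)\n",
"        current_state[i, j] *= -1  # propose flipping one pixel\n",
"        new_E = N * M * energy(current_state)\n",
"        if new_E < E or np.exp(-(new_E - E) / T) > rng.random():\n",
"            E = new_E\n",
"        else:\n",
"            current_state[i, j] *= -1  # reject the change\n",
"        measurements[step] = measurement(current_state)\n",
"    return measurements\n",
"\n",
"\n",
"rng = np.random.default_rng(42)\n",
"colors = mcmc_with_measurement(np.ones((20, 20)), steps=2000, T=5, measurement=average_color, rng=rng)\n",
"print(colors.shape, colors.mean())\n",
"```\n",
"\n",
"Swapping in a measurement that returns a vector, for example a whole row of the state, would fail at `measurements[step] = ...` with a numpy ValueError. That is exactly the limitation described above."
]
},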
{
"cell_type": "markdown",
"id": "c7c9575f-2450-4298-a507-90f0c1b9b284",
"metadata": {
"tags": []
},
"source": [
"### Use Inheritance\n",
"```python\n",
"# This class would define the basic functionality of performing MCMC\n",
"class MCMCSampler(object):\n",
"    def run(self, initial_state, steps, T):\n",
"        ...\n",
"        for i in range(steps):\n",
"            ...\n",
"            self.measurement(current_state)\n",
"\n",
"\n",
"# This class would inherit from it and just implement the measurement\n",
"class AverageColorSampler(MCMCSampler):\n",
"    def __init__(self, steps):\n",
"        # per-instance storage, so two samplers don't share one array\n",
"        self.measurements = np.zeros(steps)\n",
"        self.index = 0\n",
"\n",
"    def measurement(self, state):\n",
"        self.measurements[self.index] = average_color(state)\n",
"        self.index += 1\n",
"\n",
"\n",
"color_sampler = AverageColorSampler(...)\n",
"color_sampler.run(...)\n",
"measurements = color_sampler.measurements\n",
"```\n",
"\n",
"This would definitely work, but I'm personally not a huge fan of object-oriented programming, so I'm going to skip this option!"
]
},
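{
"cell_type": "markdown",
"id": "7d05d25d-c9ba-406d-9977-0ca4aeb430a7",
"metadata": {},
"source": [
"### Use a generator\n",
"This is the approach I ended up settling on: we will turn `mcmc` into a [Python generator function](https://peps.python.org/pep-0255/). Instead of returning once at the end, it `yield`s the current state after every step and then carries on from where it left off. The code that consumes the chain then decides what to measure and how to store it."
]
},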
{
"cell_type": "code",
"execution_count": null,
"id": "f73d6335-6514-45b1-9128-d72122d8b0b7",
"metadata": {},
"outputs": [],
"source": [
"@jit(nopython=True, nogil=True)\n",
"def mcmc(initial_state, steps, T, energy=energy):\n",
"    N, M = initial_state.shape\n",
"    assert N == M\n",
"\n",
"    current_state = initial_state.copy()\n",
"    E = N**2 * energy(current_state)\n",
"    for _ in range(steps):\n",
"        i, j = np.random.randint(N), np.random.randint(N)\n",
"\n",
"        # modify the state a little, here we just flip a random pixel\n",
"        current_state[i, j] *= -1\n",
"        new_E = N**2 * energy(current_state)\n",
"\n",
"        if (new_E < E) or np.exp(-(new_E - E) / T) > np.random.random():\n",
"            E = new_E\n",
"        else:\n",
"            current_state[i, j] *= -1  # reject the change we made\n",
"\n",
"        # give a copy of the state to the enclosing code without returning;\n",
"        # yielding current_state itself would hand out the same array every\n",
"        # step, which later steps keep overwriting\n",
"        yield current_state.copy()\n",
"\n",
"    return  # this signals that we're done\n",
"\n",
"\n",
"initial_state = np.ones(shape=(50, 50))\n",
"np.array([s for s in mcmc(initial_state, steps=10, T=5)])"
]
},
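{
"cell_type": "markdown",
"id": "generator-consumer-sketch",
"metadata": {},
"source": [
"To see what the generator buys us, here is a minimal plain-Python sketch without numba. `toy_chain` is a hypothetical stand-in for `mcmc`. It flips one random pixel per step with no accept/reject test, just so there is a stream of states to consume. The caller picks the measurement and how much of the chain to keep. The `copy=False` case shows why the generator should yield a copy of the state.\n",
"\n",
"```python\n",
"from itertools import islice\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def toy_chain(initial_state, steps, rng, copy=True):\n",
"    state = initial_state.copy()\n",
"    N, M = state.shape\n",
"    for _ in range(steps):\n",
"        i, j = rng.integers(N), rng.integers(M)\n",
"        state[i, j] *= -1\n",
"        yield state.copy() if copy else state\n",
"\n",
"\n",
"def average_color(state):\n",
"    return state.mean()\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"initial_state = np.ones((10, 10))\n",
"\n",
"# The caller decides what to measure...\n",
"colors = np.array([average_color(s) for s in toy_chain(initial_state, 1000, rng)])\n",
"\n",
"# ...and how much of the chain to keep, e.g. only every 100th state\n",
"thinned = list(islice(toy_chain(initial_state, 1000, rng), 0, None, 100))\n",
"print(colors.shape, len(thinned))  # (1000,) 10\n",
"\n",
"# Without the copy, every element is the same array object\n",
"aliased = list(toy_chain(initial_state, 5, rng, copy=False))\n",
"print(all(s is aliased[0] for s in aliased))  # True\n",
"```"
]
},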
{
"cell_type": "code",
"execution_count": null,
"id": "193b778f-5913-48f1-9df6-304ab50ceb4e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:recode]",
"language": "python",
"name": "conda-env-recode-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}