Heatmaps are a powerful way to visualize data arranged as a matrix, with each value represented by a color. They are particularly useful for showing correlations, frequencies, and distributions in data sets. This tutorial walks through several ways to create and customize heatmaps with the `matplotlib` library.

### Basic Heatmap with `imshow`

The simplest way to create a heatmap in Matplotlib is with the `imshow` function.
```python
import matplotlib.pyplot as plt
import numpy as np

# Sample data
data = np.random.rand(10, 10)

plt.imshow(data, cmap='viridis')
plt.colorbar()
plt.title('Basic Heatmap with imshow')
plt.show()
```
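By default, `imshow` stretches the colormap between the smallest and largest values in `data`, so two heatmaps of different arrays can assign the same color to different values. As a minimal sketch, assuming you want panels of values in the range 0 to 1 to be directly comparable, you can pin the scale with `vmin` and `vmax`. The two random arrays below are made-up sample data:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
first = rng.random((10, 10))          # values spread over [0, 1)
second = 0.5 * rng.random((10, 10))   # values only in [0, 0.5)

fig, axes = plt.subplots(1, 2, figsize=(9, 4))
images = []
for ax, data, name in zip(axes, [first, second], ['first', 'second']):
    # vmin/vmax fix both ends of the colormap, so equal colors mean equal values
    im = ax.imshow(data, cmap='viridis', vmin=0.0, vmax=1.0)
    ax.set_title(name)
    images.append(im)

# Both images share one scale, so a single colorbar describes them
fig.colorbar(images[0], ax=axes)
plt.show()
```

Without `vmin` and `vmax`, the `second` panel would be stretched to use the full colormap and would look just as bright as `first`.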
- **`imshow(data, cmap='viridis')`** displays the 2D array `data` as an image, mapping each value to a color from the `viridis` colormap.
- **`colorbar()`** adds a color bar that shows which value each color represents.

### Customizing Heatmap Colors

You can change the colormap and other display properties to make the heatmap easier to read.
```python
import matplotlib.pyplot as plt
import numpy as np

# Sample data
data = np.random.rand(10, 10)

plt.imshow(data, cmap='hot', interpolation='nearest')
plt.colorbar()
plt.title('Custom Heatmap Colors')
plt.show()
```
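To see what `interpolation` controls, this sketch draws the same small random array twice with the `hot` colormap: once with `'nearest'` and once with `'bilinear'`, another interpolation method `imshow` accepts. The 5 by 5 size is an arbitrary choice; a small array just makes the blending easier to see.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
data = rng.random((5, 5))  # small array so the smoothing is easy to see

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
images = []
for ax, method in zip(axes, ['nearest', 'bilinear']):
    # Same data and colormap; only the interpolation method differs
    im = ax.imshow(data, cmap='hot', interpolation=method)
    ax.set_title(f"interpolation='{method}'")
    images.append(im)

fig.colorbar(images[0], ax=axes)
plt.show()
```

With `'nearest'` every cell keeps one flat color; with `'bilinear'` neighboring cells blend together, which can suggest values that are not actually in the data.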
- **`cmap='hot'`** uses the `hot` colormap, which runs from black through red and yellow to white.
- **`interpolation='nearest'`** turns off smoothing, so each value is drawn as a uniform square block of color.

### Heatmap with Annotations

Annotations make a heatmap more informative by printing the exact value inside each cell.
```python
import matplotlib.pyplot as plt
import numpy as np

# Sample data
data = np.random.rand(10, 10)

fig, ax = plt.subplots()
cax = ax.imshow(data, cmap='viridis')

# Add color bar
fig.colorbar(cax)

# Annotate the heatmap
for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center', color='w')

plt.title('Heatmap with Annotations')
plt.show()
```
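White text can be hard to read on the bright yellow end of `viridis`. One common adjustment, sketched here as a hand-written loop rather than a built-in Matplotlib option, is to pick the text color per cell: dark text on light cells and light text on dark cells. The midpoint threshold is an assumption that works because `viridis` gets lighter as values increase.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(2)
data = rng.random((6, 6))

fig, ax = plt.subplots()
im = ax.imshow(data, cmap='viridis')
fig.colorbar(im)

# Halfway between the smallest and largest value: viridis is dark below, light above
threshold = (data.min() + data.max()) / 2

texts = []
for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        color = 'black' if data[i, j] > threshold else 'white'
        texts.append(ax.text(j, i, f'{data[i, j]:.2f}',
                             ha='center', va='center', color=color))

ax.set_title('Annotations with per-cell text color')
plt.show()
```

For a colormap whose lightness does not increase with value, this simple threshold would need to be replaced by a check on the actual cell color.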
- **`ax.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center', color='w')`** places a white label at the center of each cell showing its value to 2 decimal places. The x position comes first, so the column index `j` precedes the row index `i`.

### Using `pcolormesh` for Heatmaps

The `pcolormesh` function is more flexible than `imshow` because you supply the cell coordinates explicitly, which makes it the better choice when the cells are unevenly spaced.
```python
import matplotlib.pyplot as plt
import numpy as np

# Sample data
x = np.arange(0, 11, 1)
y = np.arange(0, 11, 1)
X, Y = np.meshgrid(x, y)
Z = np.sin(X) * np.cos(Y)

plt.pcolormesh(X, Y, Z, shading='auto', cmap='coolwarm')
plt.colorbar()
plt.title('Heatmap with pcolormesh')
plt.show()
```
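The flexibility of `pcolormesh` shows up when cells have different sizes. In this sketch the edge positions and the 3 by 4 array of values are made up for illustration. With `shading='flat'`, `x_edges` and `y_edges` give the cell boundaries, so each must have one more entry than `Z` has columns or rows, respectively:

```python
import matplotlib.pyplot as plt
import numpy as np

# Cell boundaries: 4 columns of increasing width, 3 rows of varying height
x_edges = np.array([0, 1, 2, 4, 8])
y_edges = np.array([0, 0.5, 1.5, 3])

# One value per cell, shaped (rows, columns) = (len(y_edges) - 1, len(x_edges) - 1)
Z = np.arange(12).reshape(3, 4)

fig, ax = plt.subplots()
mesh = ax.pcolormesh(x_edges, y_edges, Z, shading='flat', cmap='coolwarm')
fig.colorbar(mesh)
ax.set_title('pcolormesh with unevenly spaced cells')
plt.show()
```

`imshow` would draw these 12 values as equal squares; `pcolormesh` keeps each cell at its true width and height.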
- **`pcolormesh(X, Y, Z, shading='auto', cmap='coolwarm')`** draws one colored cell per value in `Z`, using `X` and `Y` as the grid coordinates. Because `X`, `Y`, and `Z` all have the same shape here, `shading='auto'` treats the coordinates as cell centers.

### Creating Heatmaps with `matshow`

Another convenient option is the `matshow` function, a variant of `imshow` designed for displaying matrix data.
```python
import matplotlib.pyplot as plt
import numpy as np

# Sample data
data = np.random.rand(10, 10)

plt.matshow(data, cmap='viridis')
plt.colorbar()
plt.title('Heatmap with matshow')
plt.show()
```
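Because `plt.matshow` always opens a new figure, it cannot draw into a subplot you have already created. The `Axes.matshow` method does the same job inside an existing layout. This sketch places it next to a plain `imshow` of the same non-square array (a made-up 4 by 6 example) so you can compare where the tick labels end up:

```python
import matplotlib.pyplot as plt
import numpy as np

# A non-square array: 4 rows, 6 columns
data = np.arange(24).reshape(4, 6)

fig, (left, right) = plt.subplots(1, 2, figsize=(9, 3))
left.imshow(data, cmap='viridis')
left.set_title('imshow')

# Same data, drawn with matrix-style axes in an existing subplot
mat = right.matshow(data, cmap='viridis')
right.set_title('Axes.matshow')

fig.colorbar(mat, ax=[left, right])
plt.show()
```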
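- **`matshow(data, cmap='viridis')`** opens a new figure and displays the array the way a matrix is printed: row 0 at the top, column tick labels along the top edge, and ticks only at whole-number positions so they line up with row and column indices.

### Highlighting Specific Regions

You can emphasize specific regions of a heatmap by overlaying contour lines.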
```python
import matplotlib.pyplot as plt
import numpy as np

# Sample data
x = np.arange(0, 11, 1)
y = np.arange(0, 11, 1)
X, Y = np.meshgrid(x, y)
Z = np.sin(X) * np.cos(Y)

plt.pcolormesh(X, Y, Z, shading='auto', cmap='coolwarm')
plt.colorbar()

# Adding contours
plt.contour(X, Y, Z, colors='black', linewidths=0.5)

plt.title('Heatmap with Contours')
plt.show()
```
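By default `contour` chooses its own levels. When specific thresholds matter, you can pass them as `levels` and label each line with `clabel`. In this sketch the three levels are arbitrary choices, and the finer 101-point grid is an assumption added only to make the curves smoother:

```python
import matplotlib.pyplot as plt
import numpy as np

# Finer grid than the example above, for smoother contour lines
x = np.linspace(0, 10, 101)
y = np.linspace(0, 10, 101)
X, Y = np.meshgrid(x, y)
Z = np.sin(X) * np.cos(Y)

fig, ax = plt.subplots()
mesh = ax.pcolormesh(X, Y, Z, shading='auto', cmap='coolwarm')
fig.colorbar(mesh)

# Draw lines only at the chosen values and print each level on its line
cs = ax.contour(X, Y, Z, levels=[-0.5, 0.0, 0.5], colors='black', linewidths=0.5)
ax.clabel(cs, fmt='%.1f', fontsize=8)

ax.set_title('Heatmap with labeled contours')
plt.show()
```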
- **`contour(X, Y, Z, colors='black', linewidths=0.5)`** overlays thin black lines along curves of constant `Z`, outlining the regions where the values rise above or fall below each contour level.

### Heatmap with Custom Ticks and Labels

Custom ticks and labels make a heatmap easier to read by naming each row and column.
```python
import matplotlib.pyplot as plt
import numpy as np

# Sample data
data = np.random.rand(10, 10)
labels = [chr(65 + i) for i in range(10)]  # 'A' through 'J'

fig, ax = plt.subplots()
cax = ax.imshow(data, cmap='viridis', interpolation='nearest')

# Add color bar
fig.colorbar(cax)

# Customizing ticks and labels
ax.set_xticks(np.arange(len(labels)))
ax.set_yticks(np.arange(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)

plt.title('Heatmap with Custom Ticks and Labels')
plt.show()
```
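Since Matplotlib 3.5, `set_xticks` and `set_yticks` also accept a `labels` argument, so positions and labels can be set in one call. This sketch uses that form with a non-square array and longer, hypothetical row and column names; rotating the column labels keeps them from overlapping. On older Matplotlib versions, keep the separate `set_xticklabels` and `set_yticklabels` calls from the example above.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(3)
row_labels = ['North', 'South', 'East', 'West']  # hypothetical names
col_labels = ['January', 'February', 'March', 'April', 'May', 'June']
data = rng.random((len(row_labels), len(col_labels)))

fig, ax = plt.subplots()
im = ax.imshow(data, cmap='viridis')
fig.colorbar(im)

# Tick positions and their text in one call (Matplotlib 3.5 or later)
ax.set_xticks(np.arange(len(col_labels)), labels=col_labels)
ax.set_yticks(np.arange(len(row_labels)), labels=row_labels)

# Rotate the long column names so they do not overlap
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')

ax.set_title('Heatmap with named rows and columns')
fig.tight_layout()
plt.show()
```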
- **`ax.set_xticks`** and **`ax.set_yticks`** place one tick at the center of each column and row.
- **`ax.set_xticklabels`** and **`ax.set_yticklabels`** replace the numeric tick labels with the letters in `labels`.

### Example: Correlation Heatmap

A common use case for heatmaps is visualizing the correlation matrix of a dataset.
```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Sample data
data = {
    'A': np.random.rand(10),
    'B': np.random.rand(10),
    'C': np.random.rand(10),
    'D': np.random.rand(10)
}
df = pd.DataFrame(data)

# Compute correlation matrix
corr = df.corr()

# Create a figure and axis
fig, ax = plt.subplots(figsize=(8, 6))

# Create a heatmap using matshow
cax = ax.matshow(corr, cmap='coolwarm')

# Add annotations to each cell
for (i, j), value in np.ndenumerate(corr.values):
    ax.text(j, i, f'{value:.2f}', ha='center', va='center', color='black')

# Add a colorbar
fig.colorbar(cax)

# Set axis labels
ax.set_xticks(np.arange(len(df.columns)))
ax.set_yticks(np.arange(len(df.columns)))
ax.set_xticklabels(df.columns)
ax.set_yticklabels(df.columns)

plt.title('Correlation Heatmap')
plt.show()
```
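Correlation coefficients always lie between -1 and 1, but `matshow` scales the colormap to the values actually present, so the neutral middle color of `coolwarm` need not correspond to zero. A minimal variation, assuming the same kind of random sample data, fixes the scale with `vmin=-1` and `vmax=1` so that blue always means negative and red always means positive correlation. It uses the `labels` argument of `set_xticks`, which requires Matplotlib 3.5 or later:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Random sample data: 10 observations of 4 variables
rng = np.random.default_rng(4)
df = pd.DataFrame(rng.random((10, 4)), columns=['A', 'B', 'C', 'D'])
corr = df.corr()

fig, ax = plt.subplots(figsize=(6, 5))

# Pin the scale to the full correlation range so 0 maps to the colormap's midpoint
im = ax.matshow(corr, cmap='coolwarm', vmin=-1, vmax=1)
fig.colorbar(im)

ticks = np.arange(len(corr.columns))
ax.set_xticks(ticks, labels=list(corr.columns))
ax.set_yticks(ticks, labels=list(corr.columns))

for (i, j), value in np.ndenumerate(corr.values):
    ax.text(j, i, f'{value:.2f}', ha='center', va='center')

ax.set_title('Correlation Heatmap (fixed -1 to 1 scale)')
plt.show()
```

With the scale fixed, heatmaps of different datasets become directly comparable, and weak correlations stay close to the neutral middle color instead of being stretched toward the extremes.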