# Colormaps

Before we start, here are some useful links on topics covered in this post:

• Using Linear Segmented Colormaps
• Scaling the indices in a cmap by a function

## The basics

The color dictionary of a LinearSegmentedColormap has an entry for each of ‘red’, ‘green’, and ‘blue’. Each entry is a sequence of tuples with three elements of the form (x, y0, y1), which set the R, G, or B values used for interpolation along the colorbar. I find it more instructive to think about them as (cm_index, below_interp_to, above_interp_from):

•   ‘cm_index’ is the index in the colormap, from 0 to 1, that corresponds to the normalized range of the data values. That is, cm_index=0 in the colormap corresponds to the minimum data value plotted, and cm_index=1 in the colormap corresponds to the maximum data value that is plotted.  If you think about it like a colorbar, cm_index=0 is the bottom color, and cm_index=1 is the top color.
•   ‘below_interp_to’ is the R, G, or B value (from 0 to 1) that interpolation from the previous tuple runs to. In other words, it is the value used just below the ‘cm_index’ of this tuple.
•   ‘above_interp_from’ is the R, G, or B value (from 0 to 1) that interpolation toward the next tuple starts from. In other words, it is the value used just above the ‘cm_index’ of this tuple.

How to create a colormap:

```python
from matplotlib import colors, cm, pyplot as plt

cdict = {
    'red':   ((0., 0., 0.), (0.5, 0.25, 0.25), (1., 1., 1.)),
    'green': ((0., 1., 1.), (0.7, 0.0, 0.5), (1., 1., 1.)),
    'blue':  ((0., 1., 1.), (0.5, 0.0, 0.0), (1., 1., 1.)),
}

# Generate the colormap with 1024 interpolated values
my_cmap = colors.LinearSegmentedColormap('my_colormap', cdict, 1024)
```
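A quick sanity check can save some head-scratching before a cdict reaches LinearSegmentedColormap. As far as I know, matplotlib insists that each channel's cm_index values start at 0, end at 1, and never decrease, and it raises a ValueError when the colormap is first used if they do not. The sketch below checks those rules. It also checks that every color value lies between 0 and 1, which is my own extra test and not a matplotlib requirement. `check_cdict` is a hypothetical helper, not part of matplotlib.

```python
def check_cdict(cdict):
    """Raise ValueError if a segmentdata dict breaks the (x, y0, y1) rules.

    Hypothetical helper: for each of 'red', 'green' and 'blue' it checks that the
    cm_index values start at 0, end at 1 and never decrease, and (an extra check
    of my own) that every color value lies in [0, 1].
    """
    for key in ('red', 'green', 'blue'):
        rows = cdict[key]
        xs = [row[0] for row in rows]
        if xs[0] != 0 or xs[-1] != 1:
            raise ValueError("%s: cm_index must start at 0 and end at 1" % key)
        if any(later < earlier for earlier, later in zip(xs, xs[1:])):
            raise ValueError("%s: cm_index values must be in increasing order" % key)
        for x, y0, y1 in rows:
            if not (0 <= y0 <= 1 and 0 <= y1 <= 1):
                raise ValueError("%s: color values at cm_index %s must lie in [0, 1]" % (key, x))


# The example cdict from above passes silently
cdict = {
    'red':   ((0., 0., 0.), (0.5, 0.25, 0.25), (1., 1., 1.)),
    'green': ((0., 1., 1.), (0.7, 0.0, 0.5), (1., 1., 1.)),
    'blue':  ((0., 1., 1.), (0.5, 0.0, 0.0), (1., 1., 1.)),
}
check_cdict(cdict)
```

Calling `check_cdict(cdict)` right before `colors.LinearSegmentedColormap(...)` reports a malformed channel by name.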



Accessing the values in an existing colormap:

```python
cmap_old = plt.get_cmap('jet')   # cm.get_cmap('jet') in older matplotlib versions

# _segmentdata is a private attribute holding the cdict the colormap was built from
cmapvals = cmap_old._segmentdata

print(cmapvals)
#  -->
#    {'blue': ((0.0, 0.5, 0.5),
#              (0.11, 1, 1),
#              (0.34, 1, 1),
#              (0.65, 0, 0),
#              (1, 0, 0)),
#     'green': ((0.0, 0, 0),
#              (0.125, 0, 0),
#              (0.375, 1, 1),
#              (0.64, 1, 1),
#              (0.91, 0, 0),
#              (1, 0, 0)),
#     'red': ((0.0, 0, 0),
#             (0.35, 0, 0),
#             (0.66, 1, 1),
#             (0.89, 1, 1),
#             (1, 0.5, 0.5))}
```


To see how this works, let’s look at some possible tuples in ‘blue’ (if this gets confusing, just jump ahead to the examples that follow):

• (0,0,0)  –> This tuple defines the beginning point, cm_index=0 (the bottom of the colorbar, for the lowest data value). Blue starts at B=0 here.
• (.25,0.1,0.1) –> This says that for the colormap value 0.25 (one-quarter of the way up the colorbar), B should be interpolated from the value in the previous tuple (B=0, from the one we talked about above) to B=0.1.  Furthermore, the interpolation value to use between cm_index=0.25 and that of the next tuple (0.5, coming up next) is also B=0.1.
• (0.5,0.8,1) –> This says that for cm_index=0.5 (the half-way point on the colorbar), the blue value should interpolate B values between the previous definition point (B=0.1 at cm_index=0.25 from before) and 0.8. It also sets B=1.0 as the starting value for interpolation between cm_index=0.5 and the next point. Because we set different values, 0.8 and 1.0, for below_interp_to and above_interp_from, you will see a break in the colormap at cm_index=0.5. In other words, the colormap will not have a completely smooth color transition at the half-way point. We’ll see an example of this later.
• (1,1,1) –> This tuple defines the end point: cm_index=1, the top of the colorbar. For points between the last tuple and this one (cm_index=0.5 to 1), interpolate from B=1.0 (the above_interp_from value of the previous tuple) to B=1 (the below_interp_to value here). Blue therefore stays constant at 1 over the top half.
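You can check the tuple rules above by hand with a small stand-alone function. This sketch implements the interpolation rule as I understand it from the LinearSegmentedColormap documentation: between two neighbouring tuples, the value runs linearly from above_interp_from of the lower tuple to below_interp_to of the upper one. Exactly at a cm_index, the function returns below_interp_to. That is my reading of how matplotlib builds its lookup table, so treat that detail as an assumption. `channel_value` is a hypothetical name.

```python
from bisect import bisect_left

def channel_value(segments, x):
    """Value of one color channel at colormap index x (0 <= x <= 1).

    segments is a sequence of (cm_index, below_interp_to, above_interp_from)
    tuples, with cm_index running from 0 to 1 in increasing order.
    """
    if not 0.0 <= x <= 1.0:
        raise ValueError("colormap index must lie in [0, 1]")
    xs = [row[0] for row in segments]
    if x == xs[0]:
        return segments[0][2]          # the very bottom uses the first above_interp_from
    k = bisect_left(xs, x)             # first tuple whose cm_index is >= x
    x_lo, _, start = segments[k - 1]   # interpolation starts at above_interp_from ...
    x_hi, end, _ = segments[k]         # ... and ends at below_interp_to
    return start + (x - x_lo) / (x_hi - x_lo) * (end - start)


# The 'blue' example from the bullets above
blue = ((0, 0, 0), (0.25, 0.1, 0.1), (0.5, 0.8, 1), (1, 1, 1))
for idx in (0.0, 0.125, 0.25, 0.375, 0.5, 0.51, 0.75, 1.0):
    print(idx, round(channel_value(blue, idx), 3))
```

The jump from 0.8 at cm_index=0.5 to 1.0 just above it is the break described in the third bullet.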

Example: Make a simple colorbar of entirely blue values, starting from black. (Black is R=0,G=0,B=0.  White is R=1,G=1,B=1.  Pure Blue is R=0,G=0,B=1.)
To do this, we can set our ‘blue’ values to start at B=0 at cm_index=0, then interpolate up to B=1 at cm_index=1. We don’t want red or green in this example, so we will set them both to R=0 and G=0 at cm_index=0 as well as at cm_index=1.

```python
cdict_blue = {
    'blue':  ((0, 0, 0), (1, 1, 1)),
    'green': ((0, 0, 0), (1, 0, 0)),
    'red':   ((0, 0, 0), (1, 0, 0)),
}

cmap_blue = colors.LinearSegmentedColormap('cmap_blue', cdict_blue, 1024)
```


Whenever the R,G,B values all equal the same value, you have a shade of gray.  Thus, we can make a simple grayscale with:

```python
cdict_gray = {
    'blue':  ((0, 0, 0), (1, 1, 1)),
    'green': ((0, 0, 0), (1, 1, 1)),
    'red':   ((0, 0, 0), (1, 1, 1)),
}

cmap_gray = colors.LinearSegmentedColormap('cmap_gray', cdict_gray, 1024)
```


Now, let’s make a special cmap that uses entirely red values (starting with black at the bottom). This time, let’s give it a sharp break at 80%, where it jumps to pure blue and then fades to pure white. This is useful if you want to show that some of your data goes above a threshold (80% here) while still showing the variation above that threshold in another color (blue), instead of simply blanking it at a uniform value.

```python
cdict_redtoblue = {
    'blue':  ((0, 0, 0), (0.8, 0, 1), (1, 1, 1)),
    'green': ((0, 0, 0), (0.8, 0, 0), (1, 1, 1)),
    'red':   ((0, 0, 0), (0.8, 1, 0), (1, 1, 1)),
}

cmap_redtoblue = colors.LinearSegmentedColormap('cmap_redtoblue', cdict_redtoblue, 1024)
```


Code for the basic colormaps plot:

```python
# Plotting the example colormaps
import numpy as np
from matplotlib import colors, cm, pyplot as plt

zvals = [np.arange(1, 10, .1)]  # a single row of steadily increasing values

fig = plt.figure(1)
ax1 = fig.add_subplot(3, 1, 1)
plt.imshow(zvals, cmap=cmap_blue)
plt.title(r'cmap_blue')
ax1.get_xaxis().set_visible(False); ax1.get_yaxis().set_visible(False)

ax2 = fig.add_subplot(3, 1, 2)
plt.imshow(zvals, cmap=cmap_gray)
plt.title(r'cmap_gray')
ax2.get_xaxis().set_visible(False); ax2.get_yaxis().set_visible(False)

ax3 = fig.add_subplot(3, 1, 3)
plt.imshow(zvals, cmap=cmap_redtoblue)
plt.title(r'cmap_redtoblue')
ax3.get_yaxis().set_visible(False)

plt.savefig('cmap_examples.png', bbox_inches='tight')
plt.show()
plt.clf()
```


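If you want to register the new cmap under a name, so that you can pass cmap='my_cmap_name' like any built-in colormap, register it as shown below. The registration lasts only for the current session and is not saved on your machine.

```python
import matplotlib
matplotlib.colormaps.register(cmap=cmap_you_created, name='my_cmap_name')
# Older matplotlib versions: cm.register_cmap(name='my_cmap_name', cmap=cmap_you_created)
```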

## Scaling your data (rather than scaling the colormap)

It is often easier, and more desirable, to scale your data before plotting. The colorbar scale then stays linear; only the values you feed it change. Simply apply a transformation to your data array (such as logdata = numpy.log10(data)), then plot as you normally would. Here is an example.

```python
import numpy as np
from matplotlib import pyplot as plt

y, x = np.mgrid[0:1, 0:10:0.1]  # x is a single row of values from 0 to 9.9
z = x**(7)
zscale = z**(1/7.)
# For z = x^beta: beta = log z / log x. Take the last data point for a quick & dirty estimate of beta.
# beta = np.log10(z[0, -1]) / np.log10(x[0, -1])

plt.rc('text', usetex=True)  # needs a working LaTeX installation for the \textrm titles

fig = plt.figure(1)
ax1 = fig.add_subplot(2, 1, 1)
plt.imshow(z, cmap='jet')
plt.title('Original Data: z = x$^{\\beta}$')
ax1.get_xaxis().set_visible(False); ax1.get_yaxis().set_visible(False)

ax2 = fig.add_subplot(2, 1, 2)
plt.imshow(zscale, cmap='jet')
plt.title('Scaled Data: $\\textrm{z}^{1/ \\beta} = {\\left( \\textrm{x}^{\\beta} \\right) }^{1/\\beta}$')
ax2.get_xaxis().set_visible(False); ax2.get_yaxis().set_visible(False)

plt.savefig('cmap_powerscaledata.png', bbox_inches='tight')
plt.show()
plt.clf()
```
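The beta estimate in the comments relies on a single data point. A slightly more robust sketch fits a straight line in log-log space with np.polyfit, since log z = beta·log x for a pure power law. This assumes the data really follow z = x^beta with positive x; the x = 0 pixel is dropped because log10(0) is undefined. This is an alternative I am suggesting, not part of the original recipe.

```python
import numpy as np

y, x = np.mgrid[0:1, 0:10:0.1]
z = x**7

# log10(0) is undefined, so leave out the x = 0 pixel
pos = x > 0
slope, intercept = np.polyfit(np.log10(x[pos]), np.log10(z[pos]), 1)
beta = slope                   # log z = beta * log x for a pure power law
zscale = z**(1. / beta)        # undoing the power should give back x

print(beta, np.allclose(zscale, x))
```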


## Scaling the colormap

If you are dealing with a log scale, you can simply set the normalization to LogNorm within the plot command. For instance:

```python
import matplotlib.colors
from matplotlib import pyplot

# data must be strictly positive for a log scale
pyplot.imshow(data, cmap='jet', norm=matplotlib.colors.LogNorm(vmin=data.min(), vmax=data.max()))
```

See http://matplotlib.org/examples/pylab_examples/pcolor_log.html for an example.
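To see roughly what LogNorm does to your data, here is a small numpy sketch of the idea behind it as I understand it. Each value lands at (log v − log vmin)/(log vmax − log vmin) on the 0-to-1 colormap index. `log_normalize` is a made-up name for this illustration. Unlike matplotlib's LogNorm, which by default passes out-of-range values through to the colormap's over/under colors, it simply clips them.

```python
import numpy as np

def log_normalize(values, vmin, vmax):
    """Map positive values onto colormap indices in [0, 1] on a log scale.

    Hypothetical helper illustrating the idea behind matplotlib.colors.LogNorm;
    values outside [vmin, vmax] are clipped here for simplicity.
    """
    values = np.clip(np.asarray(values, dtype=float), vmin, vmax)
    return (np.log10(values) - np.log10(vmin)) / (np.log10(vmax) - np.log10(vmin))

data = np.array([1., 10., 100., 1000.])
print(log_normalize(data, data.min(), data.max()))  # evenly spaced indices for each decade
```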

If you want to scale the colormap indices to a function manually, there are a few resources that are already available online.  The scipy documentation has a helpful piece of code to accomplish this.  Below is a slightly modified version of the cmap_xmap function available at http://www.scipy.org/Cookbook/Matplotlib/ColormapTransformations

```python
import matplotlib.colors
from matplotlib import cm

def cmap_xmap(function, cmap):
    """ Applies function, on the indices of colormap cmap. Beware, function
    should map the [0, 1] segment to itself, or you are in for surprises.
    """
    cdict = cmap._segmentdata.copy()
    function_to_map = (lambda x: (function(x[0]), x[1], x[2]))
    for key in ('red', 'green', 'blue'):
        # sorted() also turns the map() iterator into a list on Python 3
        cdict[key] = sorted(map(function_to_map, cdict[key]))
        # Check the first and last cm_index of the remapped channel
        assert cdict[key][0][0] >= 0 and cdict[key][-1][0] <= 1, \
            "Resulting indices extend out of the [0, 1] segment."
    return matplotlib.colors.LinearSegmentedColormap('colormap', cdict, 1024)

# Example usage:
pscale = 2.
cmap_old = cm.jet
cmap_new = cmap_xmap(lambda funcvar: funcvar**pscale, cmap_old)
```


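If your function produces cmap indices outside the 0 to 1 range (np.log10, for example, sends index 0 to -inf), cmap_xmap's assertion is meant to reject it. Matplotlib could not build such a colormap anyway, because each channel's indices must start at exactly 0 and end at exactly 1. The easy fix is to rescale the function so that it maps [0, 1] onto itself. For a logarithmic stretch, log10(1 + 9x) does exactly that, since it equals 0 at x=0 and 1 at x=1:

```python
cmap_new = cmap_xmap(lambda funcvar: np.log10(1 + 9*funcvar), cmap_old)

plt.imshow(..., cmap=cmap_new)
```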
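Because LinearSegmentedColormap needs every channel's indices to run from exactly 0 to exactly 1, it helps to rescale any increasing function f so that it maps [0, 1] onto itself: g(x) = (f(x) − f(0)) / (f(1) − f(0)). A minimal sketch follows. `rescale_to_unit` is a hypothetical helper, and it assumes f is finite and increasing on [0, 1].

```python
import math

def rescale_to_unit(f):
    """Return g with g(x) = (f(x) - f(0)) / (f(1) - f(0)), so g(0) = 0 and g(1) = 1.

    Hypothetical helper. Assumes f is finite and increasing on [0, 1], in which
    case g maps [0, 1] onto itself and can be passed to cmap_xmap.
    """
    f0, f1 = f(0.0), f(1.0)
    if not f1 > f0:
        raise ValueError("f must increase between 0 and 1")
    return lambda x: (f(x) - f0) / (f1 - f0)

# log10(1 + x) alone only reaches log10(2) at x = 1; rescaling stretches it to 1
log_stretch = rescale_to_unit(lambda x: math.log10(1 + x))
print([round(log_stretch(x), 3) for x in (0.0, 0.25, 0.5, 1.0)])
```

With cmap_xmap as defined above, this would be used as `cmap_xmap(log_stretch, cmap_old)`.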


Here is an example of colormaps with different power scales, making use of cmap_xmap.

```python
z = [np.arange(1, 10, .1)]

# cm.nipy_spectral is the newer matplotlib name of cm.spectral
cmap_4 = cmap_xmap(lambda funcvar: funcvar**4., cm.nipy_spectral)
cmap_1_4 = cmap_xmap(lambda funcvar: funcvar**(1/4.), cm.nipy_spectral)  # Note the dot after the 4! (in Python 2, 1/4 == 0)

plt.rc('font', family='serif')
fig = plt.figure(1)
ax1 = fig.add_subplot(3, 1, 1)
plt.imshow(z, cmap='nipy_spectral')
plt.title("Original cmap 'spectral'")
ax1.get_xaxis().set_visible(False); ax1.get_yaxis().set_visible(False)

ax2 = fig.add_subplot(3, 1, 2)
plt.imshow(z, cmap=cmap_4)
plt.title("cmap 'spectral' scaled to the power p=4")
ax2.get_xaxis().set_visible(False); ax2.get_yaxis().set_visible(False)

ax3 = fig.add_subplot(3, 1, 3)
plt.imshow(z, cmap=cmap_1_4)
plt.title("cmap 'spectral' scaled to the power p=1/4")
ax3.get_xaxis().set_visible(False); ax3.get_yaxis().set_visible(False)

plt.savefig('cmap_powerscale.png', bbox_inches='tight')
plt.show()
plt.clf()
```


## Multiple Scaling

Sometimes normal colormap scaling tricks just don’t work.  It is difficult to highlight variations in faint features and bright features at the same time within a single data set.  In these cases, you would want the colormap to have high range at the low end and again at the high end, with little range in the middle.  I’ve created one simple solution below based on the code for cmap_xmap in the SciPy documentation.  Of course there are many ways to accomplish this (in better ways, I’m quite sure), but this is simply a tutorial.

We are scaling the colormap by two different powers (pscale1 and pscale2), and manually defining the point where we want it to make the transition (breakpoint).  We need to define a function that will scale the cmap indices by one power before the break point and then by the other power after. We must also take care this time to make sure that there is no overlap between the two powers – let’s say we decide to break at colormap index 0.5.  If we choose our scaling powers as 1/2. and 3 (–> x^(1/2.) and x^3), then colormap index of 0.4 would be scaled to 0.6324… (which is greater than the breakpoint of 0.5) while cmap index 0.7 would remap to 0.343 (below the breakpoint).  This would create a mix-up in the colorscale, and the resulting cbar would not follow the same color progression as the original.  One simple fix would be to set those values that would pass the breakpoint to be just the breakpoint.  In other words, 0.6324 –> 0.5 and 0.343 –> 0.5.  Of course this isn’t ideal since it creates a hard break, but it’s a start.
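The overlap problem described above is easy to check numerically. In the sketch below, `raw_two_power` applies the two powers with no protection at the breakpoint. `clamped_two_power` pins any value that would cross the breakpoint to the breakpoint itself, which is the same rule pscalefn() uses. Both names, and the grid-based monotonicity test, are hypothetical helpers for illustration.

```python
def raw_two_power(c, breakpoint=0.5, p1=0.5, p2=3.0):
    """Piecewise power map with no protection at the breakpoint."""
    return c**p1 if c <= breakpoint else c**p2

def clamped_two_power(c, breakpoint=0.5, p1=0.5, p2=3.0):
    """Same map, but values that would cross the breakpoint are pinned to it."""
    if c <= breakpoint:
        return min(c**p1, breakpoint)
    return max(c**p2, breakpoint)

def is_non_decreasing(f, n=101):
    """Check f on an evenly spaced grid of n points over [0, 1]."""
    vals = [f(i / (n - 1)) for i in range(n)]
    return all(b >= a for a, b in zip(vals, vals[1:]))

print(round(raw_two_power(0.4), 4), round(raw_two_power(0.7), 4))  # the 0.6324... and 0.343 from the text
print(is_non_decreasing(raw_two_power), is_non_decreasing(clamped_two_power))
```

The raw map jumps from about 0.707 at index 0.5 down to about 0.13 just above it, which is the color mix-up described above. The clamped map never decreases.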

In this simplistic solution, we make a new function called pscalefn() that takes care of all that.

```python
# Multi-power scaler for colormaps
import matplotlib.colors

def cmap_powerscaler(cmap_in, breakpoint=.5, pscale1=2, pscale2=(1/2.)):
    # Keyword arguments are breakpoint, pscale1, pscale2
    # Defaults are 0.5, 2, 1/2.
    cdict = cmap_in._segmentdata.copy()
    function_to_map = (lambda x: (pscalefn(x[0], breakpoint, pscale1, pscale2), x[1], x[2]))
    for key in ('red', 'green', 'blue'):
        # list() is needed on Python 3, where map() returns an iterator
        cdict[key] = list(map(function_to_map, cdict[key]))
        # No re-sort needed: pscalefn never decreases, so the index order is preserved
        assert cdict[key][0][0] >= 0 and cdict[key][-1][0] <= 1, \
            "Resulting indices extend out of the [0, 1] segment."
    return matplotlib.colors.LinearSegmentedColormap('colormap', cdict, 1024)

def pscalefn(cval, breakpoint, pscale1, pscale2):
    # Below the breakpoint: scale by pscale1, but never rise above the breakpoint
    if cval <= breakpoint:
        if cval**pscale1 <= breakpoint: return cval**pscale1
        else: return breakpoint
    # Above the breakpoint: scale by pscale2, but never drop below the breakpoint
    else:
        if cval**pscale2 <= breakpoint: return breakpoint
        else: return cval**pscale2
```


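Sample usage:

```python
z = [np.arange(1, 10, .1)]

# cm.nipy_spectral is the newer matplotlib name of cm.spectral
cmap_3to1_3 = cmap_powerscaler(cm.nipy_spectral, breakpoint=.5, pscale1=3., pscale2=(1/3.))
cmap_1_3to3 = cmap_powerscaler(cm.nipy_spectral, breakpoint=.7, pscale1=(1/3.), pscale2=3.)
```

Code for the multiple-scales plot: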
```python
z = [np.arange(1, 10, .1)]

# cm.nipy_spectral is the newer matplotlib name of cm.spectral
cmap_3to1_3 = cmap_powerscaler(cm.nipy_spectral, breakpoint=.5, pscale1=3., pscale2=(1/3.))  # Note the dot!
cmap_1_3to3 = cmap_powerscaler(cm.nipy_spectral, breakpoint=.7, pscale1=(1/3.), pscale2=3.)

plt.rc('font', family='serif')
fig = plt.figure(1)
ax1 = fig.add_subplot(3, 1, 1)
plt.imshow(z, cmap='nipy_spectral')
plt.title("Original cmap 'spectral'", size=12)
ax1.get_xaxis().set_visible(False); ax1.get_yaxis().set_visible(False)

ax2 = fig.add_subplot(3, 1, 2)
plt.imshow(z, cmap=cmap_3to1_3)
plt.title("Scaled to the power p=3 up to cm_index=0.5, and to p=1/3. after", size=12)
ax2.get_xaxis().set_visible(False); ax2.get_yaxis().set_visible(False)

ax3 = fig.add_subplot(3, 1, 3)
plt.imshow(z, cmap=cmap_1_3to3)
plt.title("Scaled to the power p=1/3. up to cm_index=0.7, and to p=3 after", size=12)
ax3.get_xaxis().set_visible(False); ax3.get_yaxis().set_visible(False)

plt.savefig('cmap_powerscale_multiscale.png', bbox_inches='tight')
plt.show()
plt.clf()
```


The most useful way to see the different scales is to apply them to some real data.  Here we will again use the image of M101 as used extensively in the Kapteyn package tutorials.  http://www.astro.rug.nl/software/kapteyn/

The examples below show the following settings:

•      Original matplotlib.cm.spectral colormap with no scaling done (linear)
•      Scaled as x^6, but with a break at cmap index 0.2
•      Transition from x^3 to x^2 with the break at cmap index 0.2
•      Transition from x^1.2 to x^(1/1.5) with the break at cmap index 0.5

Code for the M101 plots:
```python
from matplotlib import cm, pyplot as plt
from astropy.io import fits as pyfits  # the standalone pyfits package now lives in astropy.io.fits
import pywcsgrid2
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

m101dat = pyfits.getdata('m101.fits')
m101hdr = pyfits.getheader('m101.fits')  # WCS header for the pywcsgrid2 axes

# cm.nipy_spectral is the newer matplotlib name of cm.spectral
cmap2 = cmap_powerscaler(cm.nipy_spectral, breakpoint=.2, pscale1=6, pscale2=6)
cmap3 = cmap_powerscaler(cm.nipy_spectral, breakpoint=.2, pscale1=3, pscale2=2)
cmap4 = cmap_powerscaler(cm.nipy_spectral, breakpoint=.5, pscale1=1.2, pscale2=(1/1.5))  # Note the dot!

fig = plt.figure(2)
ax1 = pywcsgrid2.subplot(221, header=m101hdr)  # 2x2 panel layout assumed
ax1.set_ticklabel_type("delta", nbins=6)
ax1.set_xlabel(r"$\Delta$RA (J2000)", size=8)
ax1.set_ylabel(r'$\Delta$DEC (J2000)', size=8)
ax1.axis[:].major_ticklabels.set_fontsize(10)
plt.imshow(m101dat, origin='lower', interpolation='nearest', cmap='nipy_spectral')
plt.title("Original cmap 'spectral'\n(linear)", size=8)
cax1 = inset_axes(ax1,
                  width="5%",     # width = 5% of parent_bbox width
                  height="100%",  # height : 100%
                  loc=2,
                  bbox_to_anchor=(1.05, 0, 1, 1),
                  bbox_transform=ax1.transAxes,
                  )
cbar1 = plt.colorbar(cax=cax1, orientation='vertical', cmap='nipy_spectral')
cbar1.set_label('Intensity (Some Units)', fontsize=8)
cbar1.ax.tick_params(labelsize=8)

ax2 = pywcsgrid2.subplot(222, header=m101hdr)
ax2.set_ticklabel_type("delta", nbins=6)
ax2.axis[:].major_ticklabels.set_fontsize(10)
ax2.set_xlabel(""); ax2.set_ylabel("")
plt.imshow(m101dat, origin='lower', interpolation='nearest', cmap=cmap2)
plt.title("Map scaled to the power p=6,\nwith a break at 0.2", size=8)
cax2 = inset_axes(ax2, width="5%", height="100%", loc=2, bbox_to_anchor=(1.05, 0, 1, 1), bbox_transform=ax2.transAxes, borderpad=0)
cbar2 = plt.colorbar(cax=cax2, orientation='vertical', cmap=cmap2)
cbar2.set_label('Intensity (Some Units)', fontsize=8)
cbar2.ax.tick_params(labelsize=8)

ax3 = pywcsgrid2.subplot(223, header=m101hdr)
ax3.set_ticklabel_type("delta", nbins=6)
ax3.axis[:].major_ticklabels.set_fontsize(10)
ax3.set_xlabel(""); ax3.set_ylabel("")
plt.imshow(m101dat, origin='lower', interpolation='nearest', cmap=cmap3)
plt.title("pscale1=3, pscale2=2,\nbreakpoint=.2", size=8)
cax3 = inset_axes(ax3, width="5%", height="100%", loc=2, bbox_to_anchor=(1.05, 0, 1, 1), bbox_transform=ax3.transAxes, borderpad=0)
cbar3 = plt.colorbar(cax=cax3, orientation='vertical', cmap=cmap3)
cbar3.set_label('Intensity (Some Units)', fontsize=8)
cbar3.ax.tick_params(labelsize=8)

ax4 = pywcsgrid2.subplot(224, header=m101hdr)
ax4.set_ticklabel_type("delta", nbins=6)
ax4.axis[:].major_ticklabels.set_fontsize(10)
ax4.set_xlabel(""); ax4.set_ylabel("")
plt.imshow(m101dat, origin='lower', interpolation='nearest', cmap=cmap4)
plt.title("pscale1=1.2, pscale2=(1/1.5),\nbreakpoint=.5", size=8)
cax4 = inset_axes(ax4, width="5%", height="100%", loc=2, bbox_to_anchor=(1.05, 0, 1, 1), bbox_transform=ax4.transAxes, borderpad=0)
cbar4 = plt.colorbar(cax=cax4, orientation='vertical', cmap=cmap4)
cbar4.set_label('Intensity (Some Units)', fontsize=8)
cbar4.ax.tick_params(labelsize=8)

plt.suptitle('M101 with Power-Scaled Colormaps', size=16, x=.5, y=.999)
plt.savefig('cmap_powerscale_images.png', dpi=200, bbox_inches='tight')
plt.show()
plt.clf()
```


Update: I decided to add another of my favorite colormaps.  This is similar to the STD GAMMA-II color table in IDL.  I find it really nice for astronomical images. It goes from black to blue to red to yellow to white.

```python
cdict_coolheat = {
    'red':   ((0., 0., 0.), (0.25, 0., 0.), (0.5, 1., 1.), (0.75, 1.0, 1.0), (1., 1., 1.)),
    'green': ((0., 0., 0.), (0.25, 0., 0.), (0.5, 0., 0.), (0.75, 1.0, 1.0), (1., 1., 1.)),
    'blue':  ((0., 0., 0.), (0.25, 1., 1.), (0.5, 0., 0.), (0.75, 0.0, 0.0), (1., 1., 1.)),
}

coolheat = colors.LinearSegmentedColormap('coolheat', cdict_coolheat, 1024)
```


Here it is in action with the M101 image.  On the left is the simple cmap, on the right I’ve stretched the low end to separate the background from the low-level emission.

Code for plotting M101 with the coolheat colormap:

```python
import numpy as np
import matplotlib
from matplotlib import colors, pyplot as plt
from astropy.io import fits as pyfits  # the standalone pyfits package now lives in astropy.io.fits
import pywcsgrid2
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

m101dat = pyfits.getdata('m101.fits')
m101hdr = pyfits.getheader('m101.fits')  # WCS header for the pywcsgrid2 axes

# Defining the coolheat colormap manually
cdict_coolheat = {
    'red':   ((0., 0., 0.), (0.25, 0., 0.), (0.5, 1., 1.), (0.75, 1.0, 1.0), (1., 1., 1.)),
    'green': ((0., 0., 0.), (0.25, 0., 0.), (0.5, 0., 0.), (0.75, 1.0, 1.0), (1., 1., 1.)),
    'blue':  ((0., 0., 0.), (0.25, 1., 1.), (0.5, 0., 0.), (0.75, 0.0, 0.0), (1., 1., 1.)),
}

coolheat = colors.LinearSegmentedColormap('coolheat', cdict_coolheat, 1024)

### Stretching the colormap manually
cdict_coolheatstretch = {
    'red':   ((0., 0., 0.), (0.12, 0.0, 0.0), (0.35, 0., 0.), (0.55, 1., 1.), (0.8, 1.0, 1.0), (1., 1., 1.)),
    'green': ((0., 0., 0.), (0.12, 0.0, 0.0), (0.35, 0., 0.), (0.55, 0., 0.), (0.8, 1.0, 1.0), (1., 1., 1.)),
    'blue':  ((0., 0., 0.), (0.12, 0.1, 0.1), (0.35, 1., 1.), (0.55, 0., 0.), (0.8, 0.0, 0.0), (1., 1., 1.)),
}
coolheat_stretch = colors.LinearSegmentedColormap('coolheatstretch', cdict_coolheatstretch, 1024)

### Alternatively, we can stretch the cmap with a power law using cmap_xmap
### (defined above; originally from http://www.scipy.org/Cookbook/Matplotlib/ColormapTransformations):
# coolheat_stretch = cmap_xmap(lambda funcvar: funcvar**(0.5), coolheat)

plt.rc('font', family='serif')
fig = plt.figure(3)
ax1 = pywcsgrid2.subplot(121, header=m101hdr)  # side-by-side panel layout assumed
ax1.set_ticklabel_type("delta", nbins=6)
ax1.set_xlabel(r"$\Delta$RA (J2000)", size=8)
ax1.set_ylabel(r'$\Delta$DEC (J2000)', size=8)
ax1.axis[:].major_ticklabels.set_fontsize(10)
plt.imshow(m101dat, origin='lower', interpolation='nearest', cmap=coolheat)
plt.title('Basic', size=10)
cax1 = inset_axes(ax1,
                  width="5%",     # width = 5% of parent_bbox width
                  height="100%",  # height : 100%
                  loc=2,
                  bbox_to_anchor=(1.02, 0, 1, 1),
                  bbox_transform=ax1.transAxes,
                  )
cbar1 = plt.colorbar(cax=cax1, orientation='vertical', cmap=coolheat)
cbar1.set_label('Intensity (Some Units)', fontsize=8)
cbar1.ax.tick_params(labelsize=8)

ax2 = pywcsgrid2.subplot(122, header=m101hdr)
ax2.set_ticklabel_type('delta', nbins=6)
ax2.axis[:].major_ticklabels.set_fontsize(10)
ax2.set_xlabel(""); ax2.set_ylabel("")
plt.imshow(m101dat, origin='lower', interpolation='nearest', cmap=coolheat_stretch)
plt.title('Cmap scaled manually', size=10)
cax2 = inset_axes(ax2, width="5%", height="100%", loc=2, bbox_to_anchor=(1.02, 0, 1, 1), bbox_transform=ax2.transAxes, borderpad=0)
cbar2 = plt.colorbar(cax=cax2, orientation='vertical', cmap=coolheat_stretch)
cbar2.set_label('Intensity (Some Units)', fontsize=8)
cbar2.ax.tick_params(labelsize=8)

plt.suptitle('M101 With "coolheat" Colormap', size=14, x=.5, y=.8)
plt.savefig('cmap_coolheat.png', dpi=200, bbox_inches='tight')
plt.show()
plt.clf()
```



# NaNs

NaN (Not a Number) is an interesting and useful concept. For many people, NaNs represent errors or bugs in code, but there are many applications for which they are beneficial and convenient. In this discussion we will make use of numpy masked arrays as we cover a few different cases where NaNs come into play. You should check out the official numpy.ma documentation, which has several excellent examples and a much more thorough coverage of the topic. I do not intend this guide as a replacement, but rather as a brief supplement with additional specific applications.

As always, begin by importing numpy:


```python
import numpy as np
```



First, I want to demonstrate the proper way to check if a number is an NaN. You do NOT use the usual double equal sign, because NaN compares unequal to everything, including itself. Instead, use np.isnan(). For example:

```python
a = 5
a == 5
# --> True

a = np.nan
a == np.nan # --> False
np.isnan(a) # --> True
```



For the first example, let’s generate some sample data (a quadratic with random noise) and explore what happens when we insert some NaNs.

```python
# Generate some x-vals, as well as some y-vals with scatter.
xvals = np.arange(-3, 5, .2)
yvals = xvals**2 - 7
for i in range(len(yvals)): yvals[i] += 2*np.random.randn()
```



If we wanted to perform a simple fit, we could use numpy.polyfit() (older SciPy versions also exposed it as scipy.polyfit()), since we have defined this as a quadratic. If we wanted to find the minimum, we could use either of these:

```python
yvals.min()
np.min(yvals)
```

And life goes on.  But now, what if you have an NaN for one or more of your y-array values?  Let’s make an example array:

```python
yvals_with_nans = yvals.copy()
yvals_with_nans[3:10] = np.nan
```


You can still add, subtract, multiply, etc. just fine.  Even plotting with plt.plot(xvals,yvals_with_nans) works as expected, but note what happens when you try to determine special properties of the array:

```python
yvals_with_nans.min()
# --> nan

np.min(yvals_with_nans) # --> nan

np.max(yvals_with_nans) # --> nan

np.sum(yvals_with_nans) # --> nan
```


Luckily, this particular situation is easy to fix: simply use numpy’s built-in NaN-aware functions, such as numpy.nanmin() and numpy.nanmax():

```python
np.nanmin(yvals_with_nans)
# --> -10.287...
```
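The same family includes np.nansum and np.nanmean (np.nanmean needs numpy 1.8 or later). A short comparison on a tiny made-up array shows the plain and NaN-aware versions side by side, plus a quick way to count the NaNs:

```python
import numpy as np

arr = np.array([3.0, np.nan, -1.0, 4.0, np.nan])

print(np.min(arr), np.nanmin(arr))                         # -> nan -1.0
print(np.nanmax(arr), np.nansum(arr), np.nanmean(arr))     # -> 4.0 6.0 2.0
print(np.count_nonzero(np.isnan(arr)))                     # -> 2 NaNs present
```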



But what if you want to use some regular functions on your array?  You can copy over only the parts that don’t have NaNs, but there is a more powerful way that lets you keep the original array while at the same time only using the non-NaN elements in calculations.  For this we use numpy masked arrays.

Masked arrays consist of two layers: the array of data and the array of mask values.  Experiment with the numpy.ma module:

```python
a = np.ma.masked_array([1, 2, 3, 4, 5], mask=0)
print(a)
# --> [1 2 3 4 5], mask = [False False False False False]

np.sum(a) # --> 15

# Now mask the third and fifth elements
a.mask = [False, False, True, False, True]
print(a)
# --> [1 2 -- 4 --], mask = [False False True False True]

print(a.mask) # --> [False False True False True]
print(a.data) # --> [1 2 3 4 5]

# Notice how a.data returns the full underlying data, masked elements included.

np.sum(a) # --> 7
# It only summed the un-masked elements.
```


You may also find this trick useful: you can count the masked and un-masked elements quite easily.

```python
np.sum(a.mask)  # --> 2 (masked elements)
np.sum(~a.mask) # --> 3 (un-masked elements)
```

Note the use of the tilde (~), numpy’s element-wise logical NOT. When you sum the mask, you are summing the True/False values. Since Booleans count as 1/0, the sum of [True True False True] = [1 1 0 1] is 3. If you sum the inverted mask, you are summing ~[True True False True] = [False False True False] = [0 0 1 0], which is 1. (Older numpy versions also accepted a minus sign, -a.mask, for this; current versions raise an error, so use ~.) We will make use of this in a later example.
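Masked arrays can also report the number of un-masked elements directly through their count() method. A short check that the approaches agree:

```python
import numpy as np

a = np.ma.masked_array([1, 2, 3, 4, 5], mask=[False, False, True, False, True])

n_masked = int(np.sum(a.mask))          # True counts as 1, False as 0
n_unmasked = int(np.sum(~a.mask))       # ~ flips every True/False in the mask
print(n_masked, n_unmasked, a.count())  # -> 2 3 3
```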

Now let’s invoke a special version of the masked array to automatically mask out invalid numbers like NaN and inf: numpy.ma.masked_invalid()

```python
y_masked = np.ma.masked_invalid(yvals_with_nans)

print(y_masked) # Note how the nans have been replaced by -- marks
```



Et voila! Arithmetic operations, np.min(), etc. will now work correctly on the masked array, though polyfit will not. Be sure to check out the myriad other incredibly useful masking functions, such as

•      np.ma.fix_invalid()
•      and many more…

If you simply want to remove the NaNs entirely from your array, the easiest method is to compress the masked array.  This returns a 1-D array with all masked values stricken.  Note that even for 2-D or 3-D arrays (images or cubes), the masked_array.compressed() command will still return the 1-D array of all your un-masked values.

```python
y_compressed = y_masked.compressed() # If you print, you will notice that the NaNs are now gone.
```



We probably want to keep our x-array the same length as our y-array (for fitting, plotting, etc.). We can remove the x-values corresponding to the y-array NaNs easily with the same mask:

```python
x_masked = np.ma.masked_array(xvals, mask=y_masked.mask)
x_compressed = x_masked.compressed()
print(len(xvals), len(yvals_with_nans), len(x_compressed), len(y_compressed))
# --> 40 40 33 33

# Or you can even just do it on-the-fly with
x_compressed = np.ma.masked_array(xvals, mask=y_masked.mask).compressed()
```
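If you do not need to keep the masked array around, plain boolean indexing gives the same pair of compressed arrays without numpy.ma. The sketch below uses a noise-free version of the example data. np.isfinite (which flags both NaN and inf) builds one selection that is applied to both arrays:

```python
import numpy as np

xvals = np.arange(-3, 5, .2)
yvals_with_nans = xvals**2 - 7
yvals_with_nans[3:10] = np.nan

good = np.isfinite(yvals_with_nans)   # True where y is neither NaN nor +/-inf
x_good = xvals[good]                  # the same boolean array selects the matching x-values
y_good = yvals_with_nans[good]
print(len(xvals), len(x_good), len(y_good))  # -> 40 33 33
```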


Now is a good time to insert a warning: Interpolation of missing data points should only be done with caution.  In general, it’s much better practice to perform a fit on your data (see the tutorial on fitting in python here) and get your data points that way.  The following is an example of why.  It can be useful to interpolate data for missing holes in 2-D images, though, if you need to make it continuous.

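There are two easy methods for interpolation if you have 1-D data:

• numpy.interp(): This is quick and easy if your x-values are increasing, but it will give you nonsense results if they are not. Basic usage is numpy.interp(x_new, x_array, y_array), which returns the interpolated y-value(s) at x_new.
• scipy.interpolate: This is a much more full-featured module, which also has support for 2-D arrays/images. Basic usage of its 1-D interpolator is f = scipy.interpolate.interp1d(x_array, y_array), which returns a function that you then call as f(x_new).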
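numpy.interp assumes x_array is increasing and does not check it. A minimal sketch of the safe pattern is to sort both arrays by x first with np.argsort. The toy arrays here are made up for illustration:

```python
import numpy as np

x_unsorted = np.array([3.0, 1.0, 2.0])
y_unsorted = x_unsorted**2          # y = x^2 at each x

order = np.argsort(x_unsorted)      # indices that put x in increasing order
x_sorted, y_sorted = x_unsorted[order], y_unsorted[order]

print(np.interp(2.5, x_sorted, y_sorted))  # linear between (2, 4) and (3, 9) -> 6.5
```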

Here are examples for interpolating the y-value of the compressed arrays at x=2.5 (numpy) and at x=-2.5 (scipy). Note that the first argument of np.interp is an x-value, not an array index:

```python
interpvals1 = np.interp(2.5, x_compressed, y_compressed) # --> -0.6589... (for my particular randomly-generated array)

from scipy import interpolate
interpvals2 = interpolate.interp1d(x_compressed, y_compressed, kind='cubic')  # returns a function
interpvals2(-2.5)
```



Here is a plot demonstrating why blind interpolation can be dangerous.  The interpolated value can differ by quite a bit from the expected best fit.  Of course, in order to fit a curve, you must be sure of the function’s form (polynomial, exponential, etc.).

Interpolation plot code:

```python
import numpy as np
from matplotlib import pyplot as plt

plt.rc('font', family='serif') # Just because I prefer serif fonts on my plots

# Making the arrays as we did above
xvals = np.arange(-3, 5, .2)
yvals_with_nans = xvals**2 - 7
for i in range(len(yvals_with_nans)): yvals_with_nans[i] += 2*np.random.randn()
yvals_with_nans[3:10] = np.nan

# Compressed (NaN-free) copies of both arrays, sharing the y mask
y_masked = np.ma.masked_invalid(yvals_with_nans)
y_compressed = y_masked.compressed()
x_compressed = np.ma.masked_array(xvals, mask=y_masked.mask).compressed()

interpval = np.interp(-1.8, x_compressed, y_compressed)  # x = -1.8 falls inside the NaN gap
fitvals = np.polyfit(x_compressed, y_compressed, 2)
yfit = fitvals[0]*xvals**2 + fitvals[1]*xvals + fitvals[2] # Note that I am using xvals here to smoothly fill in the empty region

# Actually making the figure
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(1, 2, 1)
plt.plot(xvals, yvals_with_nans, '.', color='#153E7E')
plt.xlabel('x'); plt.ylabel('y')
plt.xlim(-3.5, 5.5)

ax2 = fig1.add_subplot(1, 2, 2)
plt.plot(x_compressed, y_compressed, '.', color='#347C17', label='Compressed Array')
plt.plot(xvals, yfit, '-', color='k', label='Best fit')  # the quadratic fit across the full x range
plt.plot(-1.8, interpval, 's', color='#990000', label='Interpolated Point')
leg2 = plt.legend(loc=2, prop={'size': 8}, fancybox=True, handlelength=2.5, numpoints=1)
plt.xlabel('x'); plt.ylabel('y')
plt.xlim(-3.5, 5.5)
plt.title('Fit on y_compressed', size=10)

# Adding text boxes below the plots with some diagnostic info
ax2.text(0, -.2, 'Best fit:  y = %.2f x$^2$ + %.2f x + %.2f\nTrue function: y = x$^2$ - 7' % (fitvals[0], fitvals[1], fitvals[2]), ha='left', va='top', transform=ax2.transAxes, size=8, bbox=dict(boxstyle='round', fc='w'))

plt.suptitle('Sample Arrays with NaNs', size=16)

#plt.show()
plt.savefig('NaN_plot_example1.png', bbox_inches='tight')
plt.clf()
```
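To put a number on the difference between interpolating and fitting, here is a noise-free sketch. It is my own illustration, not the data in the plot: the true curve is y = x^2 − 7 with the same NaN gap, and the script compares linear interpolation at x = −1.8 with a quadratic fit on the remaining points:

```python
import numpy as np

xvals = np.arange(-3, 5, .2)
yvals_with_nans = xvals**2 - 7        # noise-free version of the example curve
yvals_with_nans[3:10] = np.nan        # the gap covers x = -2.4 ... -1.2

good = ~np.isnan(yvals_with_nans)
x_c, y_c = xvals[good], yvals_with_nans[good]

x_gap = -1.8
interp_y = np.interp(x_gap, x_c, y_c)        # straight line across the gap
coeffs = np.polyfit(x_c, y_c, 2)             # quadratic fit to the surviving points
fit_y = np.polyval(coeffs, x_gap)
true_y = x_gap**2 - 7

print('interpolated: %.2f  fit: %.2f  true: %.2f' % (interp_y, fit_y, true_y))
# -> interpolated: -3.12  fit: -3.76  true: -3.76
```

Even with perfect data, the straight line across the gap misses the curve by about 0.64. The fit recovers it because the functional form is known.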


## Images and NaNs


Many astronomical images have NaNs for valid reasons. Most often it’s because we blank everything outside of some desired region, whether it be a photometry aperture, the edge of a detector’s field-of-view, etc. But we often want to perform arithmetic on an image with NaNs, which would normally make the routines return NaN. The methods presented above for 1-D cases extend to 2-D images as well.

Let’s use a real galaxy image as an example.  You can find the LITTLE THINGS VLA neutral hydrogen velocity (Moment-1) map for CVnIdwA here (click Save Link As; it’s ~1MB).  You can learn more about the LITTLE THINGS project at the NRAO site here:
https://science.nrao.edu/science/surveys/littlethings

This image contains the velocity values in m/s of the galaxy, determined from the beam-smoothed first moment.  All the values outside of the galaxy’s emission have been blanked with NaNs.  See Hunter, et al., 2012 in AJ for more details on the data processing.

Let’s calculate the RMS of the velocity values over all of the non-NaN pixels. Recall that RMS = sqrt( sum(v_i^2) / N ), where N is the number of values summed. To make the process easier to follow, we’ll break it up into parts:

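```python
from astropy.io import fits as pyfits  # the standalone pyfits package now lives in astropy.io.fits

# Import the data with pyfits. It still has additional dimensions, so we will only load the ones we care about - the last two
# veldata = pyfits.getdata('<path to where you saved the .FITS file>/CVnIdwA_NA_XMOM1.FITS')[0,0,:,:]
veldata = pyfits.getdata('CVnIdwA_NA_XMOM1.FITS')[0, 0, :, :]

# Mask the NaN pixels so that only valid pixels enter the sums
veldata_masked = np.ma.masked_invalid(veldata)

# We will sum all the non-NaN pixels. We can calculate the number of pixels to be used in a number of ways.
n_used = np.sum(~veldata_masked.mask)
# -- OR -- #
n_used = veldata_masked.count()

rms = np.sqrt(np.sum(veldata_masked**2) / n_used)
# --> pixrms = 311076.84 m/s

# -- OR an example of on-the-fly calculation -- #
# (.size counts every valid pixel; np.count_nonzero would skip pixels whose value is exactly 0)
rms_v2 = np.sqrt(np.nansum(veldata**2) / np.ma.masked_invalid(veldata).compressed().size)

# Now if we want the RMS of the difference from the mean:
rms_diff = np.sqrt(np.sum((veldata_masked - veldata_masked.mean())**2) / n_used)
print("The RMS difference of a pixel's velocity from the mean value is %.2f m/s" % rms_diff)
# --> The RMS difference of a pixel's velocity from the mean value is 9550.46 m/s
```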


Note that we could also have done the simple RMS calculation without masked arrays, as in the on-the-fly version above. But older numpy versions (before 1.8) had no numpy.nanmean() to use for the RMS difference calculation, so masked arrays were the convenient route there. This is one example of the versatility of masked arrays.
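To check that the masked-array arithmetic agrees with numpy’s NaN-aware functions (np.nanmean and np.nanstd need numpy 1.8 or later), here is a sketch on a tiny stand-in image rather than the real velocity map. The RMS difference from the mean is the same quantity as the population standard deviation, which is what np.nanstd computes by default:

```python
import numpy as np

# A tiny stand-in for veldata: a 3x3 "image" with blanked (NaN) pixels
img = np.array([[np.nan, 1.0, 2.0],
                [3.0, np.nan, 4.0],
                [5.0, 6.0, np.nan]])

masked = np.ma.masked_invalid(img)
n_used = masked.count()

rms_masked = np.sqrt(np.sum(masked**2) / n_used)
rms_nan = np.sqrt(np.nanmean(img**2))          # same quantity with NaN-aware numpy

rms_diff_masked = np.sqrt(np.sum((masked - masked.mean())**2) / n_used)
rms_diff_nan = np.nanstd(img)                  # population std = RMS difference from the mean

print(rms_masked, rms_nan)
print(rms_diff_masked, rms_diff_nan)
```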

If we simply want to set the NaNs to a number (say, 0) for processing in some later routine, we can do it easily with np.ma.fix_invalid()

```python
veldata_fix_temp = np.ma.fix_invalid(veldata, fill_value=0)
# Then call the fixed array with:
veldata_fixed = veldata_fix_temp.data

# Or, just do it on-the-fly:
veldata_fixed = np.ma.fix_invalid(veldata, fill_value=0).data
np.max(veldata_fixed) # --> Max = 338740.47 m/s
```
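np.ma.fix_invalid is not the only route to a zero-filled image. The sketch below compares it with two alternatives, masked_invalid(...).filled(0) and a plain np.where, on a tiny stand-in array (not the real velocity map). All three leave the original array untouched:

```python
import numpy as np

img = np.array([[np.nan, 1.5],
                [2.5, np.nan]])

fixed_a = np.ma.fix_invalid(img, fill_value=0).data   # as in the text above
fixed_b = np.ma.masked_invalid(img).filled(0)          # mask first, then fill the masked pixels
fixed_c = np.where(np.isnan(img), 0, img)              # plain numpy, no masked array

print(fixed_a)
```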


You may want to leave your image with NaNs instead of filling with 0s, depending on what you’re interested in showing. pyplot.imshow() does not draw NaN pixels, so with a clever choice of colormap you can emphasize the masked regions quite effectively. It is also a convenient trick if you want the background transparent instead of black for publications, etc. It’s a matter of preference, really. See the plots below (the code to produce them is there as well). For a primer on plotting with kapteyn and pywcsgrid2, see my tutorial here.

Code for the plot replacing NaNs with 0s:

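import numpy as np
from astropy.io import fits as pyfits  # the standalone pyfits package now lives in astropy.io.fits
from matplotlib import pyplot as plt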

plt.rc('font',family='serif') #Just because I prefer serif fonts on my plots

veldata=pyfits.getdata('CVnIdwA_NA_XMOM1.FITS')[0,0,:,:]
veldata_fixed = np.ma.fix_invalid(veldata,fill_value=0).data

#To plot wcs coordinates nicely, use the kapteyn or pywcsgrid2 modules.  I demonstrate use of pywcsgrid2 here.
#If you simply wanted to plot the data as you would any other image (without WCS coords), you could use the following code instead:
#
# fig2=plt.figure(2)
# ax1=fig2.add_subplot(1,2,1)
# plt.imshow(veldata,origin='lower',interpolation='nearest')
#
# ax2=fig2.add_subplot(1,2,2)
# plt.imshow(veldata_fixed,origin='lower',interpolation='nearest')
#

import pywcsgrid2

fig2=plt.figure(2)
plt.imshow(veldata[110:390,110:390],origin='lower',interpolation='nearest',cmap='gist_heat')
plt.title('With NaNs')
ax1.text(.05,.05,'RMS vel. = %.2e m/s'%rms,ha='left',va='bottom',transform=ax1.transAxes,size=8,bbox=None)

plt.title('NaNs $\\rightarrow$ 0')