diff --git a/modest.py b/modest.py new file mode 100644 index 0000000..8cf082f --- /dev/null +++ b/modest.py @@ -0,0 +1,291 @@ +""" +Modification of Chris Beaumont's mpl-modest-image package to allow the use of +set_extent. +""" +from __future__ import print_function, division + +import matplotlib +rcParams = matplotlib.rcParams + +import matplotlib.image as mi +import matplotlib.colors as mcolors +import matplotlib.cbook as cbook +from matplotlib.transforms import IdentityTransform, Affine2D + +import numpy as np + +IDENTITY_TRANSFORM = IdentityTransform() + + +class ModestImage(mi.AxesImage): + + """ + Computationally modest image class. + ModestImage is an extension of the Matplotlib AxesImage class + better suited for the interactive display of larger images. Before + drawing, ModestImage resamples the data array based on the screen + resolution and view window. This has very little affect on the + appearance of the image, but can substantially cut down on + computation since calculations of unresolved or clipped pixels + are skipped. + The interface of ModestImage is the same as AxesImage. However, it + does not currently support setting the 'extent' property. There + may also be weird coordinate warping operations for images that + I'm not aware of. Don't expect those to work either. + """ + + def __init__(self, *args, **kwargs): + self._full_res = None + self._full_extent = kwargs.get('extent', None) + super(ModestImage, self).__init__(*args, **kwargs) + self.invalidate_cache() + + def set_data(self, A): + """ + Set the image array + ACCEPTS: numpy/PIL Image A + """ + self._full_res = A + self._A = A + + if self._A.dtype != np.uint8 and not np.can_cast(self._A.dtype, + np.float): + raise TypeError("Image data can not convert to float") + + if (self._A.ndim not in (2, 3) or + (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))): + raise TypeError("Invalid dimensions for image data") + + self.invalidate_cache() + + def invalidate_cache(self): + self._bounds = None + self._imcache = None + self._rgbacache = None + self._oldxslice = None + self._oldyslice = None + self._sx, self._sy = None, None + self._pixel2world_cache = None + self._world2pixel_cache = None + + def set_extent(self, extent): + self._full_extent = extent + self.invalidate_cache() + mi.AxesImage.set_extent(self, extent) + + def get_array(self): + """Override to return the full-resolution array""" + return self._full_res + + @property + def _pixel2world(self): + + if self._pixel2world_cache is None: + + # Pre-compute affine transforms to convert between the 'world' + # coordinates of the axes (what is shown by the axis labels) to + # 'pixel' coordinates in the underlying array. + + extent = self._full_extent + + if extent is None: + + self._pixel2world_cache = IDENTITY_TRANSFORM + + else: + + self._pixel2world_cache = Affine2D() + + self._pixel2world.translate(+0.5, +0.5) + + self._pixel2world.scale((extent[1] - extent[0]) / self._full_res.shape[1], + (extent[3] - extent[2]) / self._full_res.shape[0]) + + self._pixel2world.translate(extent[0], extent[2]) + + self._world2pixel_cache = None + + return self._pixel2world_cache + + @property + def _world2pixel(self): + if self._world2pixel_cache is None: + self._world2pixel_cache = self._pixel2world.inverted() + return self._world2pixel_cache + + def _scale_to_res(self): + """ + Change self._A and _extent to render an image whose resolution is + matched to the eventual rendering. + """ + + # Find out how we need to slice the array to make sure we match the + # resolution of the display. We pass self._world2pixel which matters + # for cases where the extent has been set. + x0, x1, sx, y0, y1, sy = extract_matched_slices(axes=self.axes, + shape=self._full_res.shape, + transform=self._world2pixel) + + # Check whether we've already calculated what we need, and if so just + # return without doing anything further. + if (self._bounds is not None and + sx >= self._sx and sy >= self._sy and + x0 >= self._bounds[0] and x1 <= self._bounds[1] and + y0 >= self._bounds[2] and y1 <= self._bounds[3]): + return + + # Slice the array using the slices determined previously to optimally + # match the display + self._A = self._full_res[y0:y1:sy, x0:x1:sx] + self._A = cbook.safe_masked_invalid(self._A) + + # We now determine the extent of the subset of the image, by determining + # it first in pixel space, and converting it to the 'world' coordinates. + + # See https://github.com/matplotlib/matplotlib/issues/8693 for a + # demonstration of why origin='upper' and extent=None needs to be + # special-cased. + + if self.origin == 'upper' and self._full_extent is None: + xmin, xmax, ymin, ymax = x0 - .5, x1 - .5, y1 - .5, y0 - .5 + else: + xmin, xmax, ymin, ymax = x0 - .5, x1 - .5, y0 - .5, y1 - .5 + + xmin, ymin, xmax, ymax = self._pixel2world.transform([(xmin, ymin), (xmax, ymax)]).ravel() + + mi.AxesImage.set_extent(self, [xmin, xmax, ymin, ymax]) + # self.set_extent([xmin, xmax, ymin, ymax]) + + # Finally, we cache the current settings to avoid re-computing similar + # arrays in future. + self._sx = sx + self._sy = sy + self._bounds = (x0, x1, y0, y1) + + self.changed() + + def draw(self, renderer, *args, **kwargs): + if self._full_res.shape is None: + return + self._scale_to_res() + super(ModestImage, self).draw(renderer, *args, **kwargs) + + +def main(): + from time import time + import matplotlib.pyplot as plt + x, y = np.mgrid[0:20000, 0:20000] + data = np.sin(x / 10.) * np.cos(y / 30.) + + f = plt.figure() + ax = f.add_subplot(111) + + # try switching between + artist = ModestImage(ax, data=data) + + ax.set_aspect('equal') + artist.norm.vmin = -1 + artist.norm.vmax = 1 + + ax.add_artist(artist) + + t0 = time() + plt.gcf().canvas.draw() + t1 = time() + + print("Draw time for %s: %0.1f ms" % (artist.__class__.__name__, + (t1 - t0) * 1000)) + + plt.show() + + +def imshow(axes, X, cmap=None, norm=None, aspect=None, + interpolation=None, alpha=None, vmin=None, vmax=None, + origin=None, extent=None, shape=None, filternorm=1, + filterrad=4.0, imlim=None, resample=None, url=None, **kwargs): + """Similar to matplotlib's imshow command, but produces a ModestImage + Unlike matplotlib version, must explicitly specify axes + """ + # if not axes._hold: + # axes.cla() + if norm is not None: + assert(isinstance(norm, mcolors.Normalize)) + if aspect is None: + aspect = rcParams['image.aspect'] + axes.set_aspect(aspect) + im = ModestImage(axes, cmap=cmap, norm=norm, interpolation=interpolation, + origin=origin, extent=extent, filternorm=filternorm, + filterrad=filterrad, resample=resample, **kwargs) + + im.set_data(X) + im.set_alpha(alpha) + axes._set_artist_props(im) + + if im.get_clip_path() is None: + # image does not already have clipping set, clip to axes patch + im.set_clip_path(axes.patch) + + # if norm is None and shape is None: + # im.set_clim(vmin, vmax) + if vmin is not None or vmax is not None: + im.set_clim(vmin, vmax) + elif norm is None: + im.autoscale_None() + + im.set_url(url) + + # update ax.dataLim, and, if autoscaling, set viewLim + # to tightly fit the image, regardless of dataLim. + im.set_extent(im.get_extent()) + + axes.images.append(im) + im._remove_method = lambda h: axes.images.remove(h) + + return im + + +def extract_matched_slices(axes=None, shape=None, extent=None, + transform=IDENTITY_TRANSFORM): + """Determine the slice parameters to use, matched to the screen. + :param ax: Axes object to query. It's extent and pixel size + determine the slice parameters + :param shape: Tuple of the full image shape to slice into. Upper + boundaries for slices will be cropped to fit within + this shape. + :rtype: tulpe of x0, x1, sx, y0, y1, sy + Indexing the full resolution array as array[y0:y1:sy, x0:x1:sx] returns + a view well-matched to the axes' resolution and extent + """ + + # Find extent in display pixels (this gives the resolution we need + # to sample the array to) + ext = (axes.transAxes.transform([(1, 1)]) - axes.transAxes.transform([(0, 0)]))[0] + + # Find the extent of the axes in 'world' coordinates + xlim, ylim = axes.get_xlim(), axes.get_ylim() + + # Transform the limits to pixel coordinates + ind0 = transform.transform([min(xlim), min(ylim)]) + ind1 = transform.transform([max(xlim), max(ylim)]) + + def _clip(val, lo, hi): + return int(max(min(val, hi), lo)) + + # Determine the range of pixels to extract from the array, including a 5 + # pixel margin all around. We ensure that the shape of the resulting array + # will always be at least (1, 1) even if there is really no overlap, to + # avoid issues. + y0 = _clip(ind0[1] - 5, 0, shape[0] - 1) + y1 = _clip(ind1[1] + 5, 1, shape[0]) + x0 = _clip(ind0[0] - 5, 0, shape[1] - 1) + x1 = _clip(ind1[0] + 5, 1, shape[1]) + + # Determine the strides that can be used when extracting the array + sy = int(max(1, min((y1 - y0) / 5., np.ceil(abs((ind1[1] - ind0[1]) / ext[1]))))) + sx = int(max(1, min((x1 - x0) / 5., np.ceil(abs((ind1[0] - ind0[0]) / ext[0]))))) + + return x0, x1, sx, y0, y1, sy + + +if __name__ == "__main__": + main() \ No newline at end of file