# Map2D.py
import operator

from numpy import array

class Map2D:
    """
    Map multidimensional data onto a 2D matrix and back again.

    The dimensions of the input are split into two groups, referred to as the
    Row and the Col(umn) dimensions. The dimensions passed as ``dims`` become
    the rows (``axis=0``) or the columns (``axis=1``) of the matrix; all
    remaining dimensions form the other group. Within each group the
    dimensions are taken in ascending order.

    Attributes:
        rowShape (tuple of int): sizes of the dimensions collapsed into rows.
        columnShape (tuple of int): sizes of the dimensions collapsed into
            columns.
        rows (int): number of matrix rows (product of rowShape, 1 if empty).
        columns (int): number of matrix columns (product of columnShape,
            1 if empty).
        map (tuple of int): axis permutation applied before flattening.
        remap (tuple of int): inverse permutation of ``map``.
    """

    def __init__(self, data, dims, axis=0):
        """Initialise Map2D object.

        Args:
            data (numpy.array): input data array; only its shape is used.
            dims (integer or iterable of integers): dimensions to collapse
                into rows (axis=0) or columns (axis=1). Negative values count
                from the last dimension. The order given is irrelevant.
            axis (integer): collapse ``dims`` into rows (0) or columns (1).
                Any value other than 1 is treated as 0.

        Raises:
            ValueError: if a dimension is out of range or given more than once.
        """
        ndim = len(data.shape)
        picked = self._normalise_dims(dims, ndim)
        rest = [d for d in range(ndim) if d not in picked]
        if axis == 1:
            rowdims, coldims = tuple(rest), tuple(picked)
        else:
            rowdims, coldims = tuple(picked), tuple(rest)
        self.rowShape = tuple(data.shape[d] for d in rowdims)
        self.columnShape = tuple(data.shape[d] for d in coldims)
        # Exact integer sizes: an empty group must give 1, not the float 1.0
        # that numpy's prod of an empty (float) array returns, which reshape
        # would reject.
        self.rows = self._size(self.rowShape)
        self.columns = self._size(self.columnShape)
        self.map = rowdims + coldims
        # Inverse permutation: original dimension map[i] sits at axis i after
        # flattening, so transposing with remap restores the original order.
        remap = [0] * len(self.map)
        for position, dim in enumerate(self.map):
            remap[dim] = position
        self.remap = tuple(remap)

    @staticmethod
    def _normalise_dims(dims, ndim):
        """Return ``dims`` as a sorted list of non-negative dimension indices.

        Args:
            dims (integer or iterable of integers): dimension index/indices;
                numpy integer scalars and arrays are accepted as well.
            ndim (int): number of dimensions of the data.

        Raises:
            ValueError: if a dimension is out of range or given more than once.
        """
        try:
            candidates = [operator.index(dims)]
        except TypeError:
            candidates = [operator.index(d) for d in dims]
        picked = []
        for d in candidates:
            if not -ndim <= d < ndim:
                raise ValueError(
                    "dimension %d out of range for %d-dimensional data" % (d, ndim))
            if d < 0:
                d += ndim
            if d in picked:
                raise ValueError("dimension %d given more than once" % d)
            picked.append(d)
        return sorted(picked)

    @staticmethod
    def _size(shape):
        """Return the number of elements for ``shape`` (1 for an empty shape)."""
        size = 1
        for extent in shape:
            size *= extent
        return size

    def __call__(self, data):
        """
        Flatten data into a 2D matrix of shape (rows, columns).

        Args:
            data (numpy.array): input data, should have the same shape as the
                                data this object was created with.

        Returns:
            numpy.array: the transposed and reshaped data.
        """
        return data.transpose(self.map).reshape((self.rows, self.columns))

    def inv(self, data):
        """Reshape a (rows, columns) matrix back onto the original dimensions.

        Inverse of ``__call__``: ``self.inv(self(data))`` reproduces ``data``.
        """
        return data.reshape(self.rowShape + self.columnShape).transpose(self.remap)

    def X(self, x):
        """Flatten row-shaped data to vectors of length ``rows``.

        A leading axis collects multiple vectors; all length-1 axes of the
        result are removed by ``squeeze``.
        """
        return x.reshape((-1, self.rows)).squeeze()

    def Xinv(self, x):
        """Reshape vectors of length ``rows`` back to ``rowShape``.

        A leading axis collects multiple vectors; all length-1 axes of the
        result are removed by ``squeeze``.
        """
        return x.reshape((-1,) + self.rowShape).squeeze()

    def Y(self, y):
        """Flatten column-shaped data to vectors of length ``columns``.

        A leading axis collects multiple vectors; all length-1 axes of the
        result are removed by ``squeeze``.
        """
        return y.reshape((-1, self.columns)).squeeze()

    def Yinv(self, y):
        """Reshape vectors of length ``columns`` back to ``columnShape``.

        A leading axis collects multiple vectors; all length-1 axes of the
        result are removed by ``squeeze``.
        """
        return y.reshape((-1,) + self.columnShape).squeeze()