Class Shape

java.lang.Object
org.flag4j.arrays.Shape
All Implemented Interfaces:
Serializable

public class Shape extends Object implements Serializable
Represents the shape of a multidimensional array (e.g. tensor, matrix, vector, etc.), specifying its dimensions and providing utilities for shape-related operations.

A shape is defined by an array of dimensions, where each dimension specifies the size of the tensor along a particular axis. Strides can also be computed for the shape; they specify the number of elements to step over along each axis when traversing an array with the given shape. Strides are always row-major (C-order) contiguous, which allows efficient array traversal and mapping of nD indices to 1D contiguous indices.
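As an illustration of how row-major strides follow from the dimensions, here is a minimal sketch. `StrideSketch.rowMajorStrides` is a hypothetical helper written for this page, not part of the flag4j API, and it omits overflow checks: the last axis gets stride 1 and each earlier axis gets the product of all later dimensions.

```java
/** Hypothetical helper illustrating row-major (C-order) stride computation. */
final class StrideSketch {
    private StrideSketch() {}

    /**
     * Computes row-major strides: the last axis has stride 1, and each earlier
     * axis has a stride equal to the product of all dimensions after it.
     * No overflow checking is performed in this sketch.
     */
    static int[] rowMajorStrides(int... dims) {
        int[] strides = new int[dims.length]; // a scalar (rank 0) shape has no strides
        int step = 1;
        for (int i = dims.length - 1; i >= 0; i--) {
            strides[i] = step;
            step *= dims[i];
        }
        return strides;
    }
}
```

For a 3x4x5 shape this yields {20, 5, 1}, so stepping once along axis 0 skips 20 elements of the flat array.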

This class also supports converting between multidimensional and flat indices, computing the shape's rank (i.e. number of dimensions), computing the total number of elements in an array with the given shape, and manipulating dimensions through swaps or permutations.

The Shape class is immutable with respect to its dimensions, ensuring thread safety and consistency. Strides are computed lazily only when needed to minimize overhead.
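The combination of immutable dimensions and lazily computed strides can be sketched as follows. `LazyShape` is a hypothetical stand-in, not flag4j's actual implementation. It assumes defensive copies protect the dimension array and that strides are cached in a `volatile` field: concurrent first calls may each compute the strides, but they compute identical arrays, and the volatile write publishes a fully initialized one.

```java
import java.util.Arrays;

/** Hypothetical immutable shape with lazily computed, cached strides. */
final class LazyShape {
    private final int[] dims;        // defensive copy; never mutated after construction
    private volatile int[] strides;  // null until first requested

    LazyShape(int... dims) {
        for (int d : dims) {
            if (d < 0) throw new IllegalArgumentException("Dimensions must be non-negative: " + Arrays.toString(dims));
        }
        this.dims = dims.clone();
    }

    int getRank() { return dims.length; }

    int[] getDims() { return dims.clone(); } // copy so callers cannot mutate internal state

    int[] getStrides() {
        int[] s = strides;
        if (s == null) {
            // Benign race: concurrent callers may each compute identical arrays;
            // the volatile write safely publishes a fully initialized array.
            s = new int[dims.length];
            int step = 1;
            for (int i = dims.length - 1; i >= 0; i--) {
                s[i] = step;
                step *= dims[i];
            }
            strides = s;
        }
        return s.clone(); // copy so the cached strides stay unmodified
    }
}
```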

Example usage:


 Shape shape = new Shape();  // Creates a shape for a scalar value.
 shape = new Shape(3, 4, 5);  // Creates a shape for a 3x4x5 tensor.
 int rank = shape.getRank();  // Gets the rank (number of dimensions): 3.
 int[] strides = shape.getStrides();  // Retrieves the strides for this shape: {20, 5, 1}.
 int flatIndex = shape.get1DIndex(2, 1, 4);  // Converts nD indices to a flat index: 49.
 int[] multiDimIndex = shape.getNdIndices(56);  // Flat index to nD index: {2, 3, 1}.
 
 
  • Constructor Summary

    Constructors
    Constructor           Description
    Shape(int... dims)    Constructs a shape object from specified dimensions.
  • Method Summary

    Modifier and Type    Method and Description
    boolean              equals(Object b)
                         Checks if an object is equal to this shape.
    Shape                flatten()
                         Flattens this shape to a rank-1 shape with dimension equal to the product of all of this shape's dimensions.
    int                  get(int i)
                         Gets the size of this shape in the specified dimension.
    int                  get1DIndex(int... nDIndex)
                         Computes the index of the 1D data array for a dense tensor from nD indices for a tensor with this shape.
    int[]                getDims()
                         Gets the shape of a tensor as an array of dimensions.
    int[]                getNdIndices(int index)
                         Efficiently computes the nD tensor index based on a 1D index from the internal 1D data array.
    int[][]              getNdIndices(int... indices)
                         Efficiently computes the nD tensor indices from multiple 1D indices from the internal 1D data array.
    int                  getRank()
                         Gets the rank of a tensor with this shape.
    int[]                getStrides()
                         Gets the strides of this shape as an array.
    int                  hashCode()
                         Generates the hashcode for this shape object.
    boolean              isIntSized()
                         Checks if the total number of elements represented by this shape can be represented as a 32-bit integer without overflowing.
    boolean              isSquare()
                         Checks if this shape is square.
    Shape                permuteAxes(int... axes)
                         Permutes the axes of this shape.
    Shape                slice(int startIdx)
                         Returns a slice of this shape starting from the specified index to the end of this shape's dimensions.
    Shape                slice(int startIdx, int stopIdx)
                         Returns a slice of this shape from the specified start index to the stop index of this shape's dimensions.
    Shape                swapAxes(int axis1, int axis2)
                         Swaps two axes of this shape.
    String               toString()
                         Converts this Shape object to a string format.
    BigInteger           totalEntries()
                         Gets the total number of elements for a tensor with this shape.
    int                  totalEntriesIntValueExact()
                         Gets the total number of elements for a tensor with this shape as an int.
    long                 totalEntriesLongValueExact()
                         Gets the total number of elements for a tensor with this shape as a long.
    int                  unsafeGet1DIndex(int... nDIndex)
                         Computes the index of the 1D data array for a dense tensor from nD indices for a tensor with this shape, without bounds checking.
    Shape                unsafePermuteAxes(int... axes)
                         Permutes the axes of this shape without validating axes.

    Methods inherited from class java.lang.Object

    clone, finalize, getClass, notify, notifyAll, wait, wait, wait
  • Constructor Details

    • Shape

      public Shape(int... dims)
      Constructs a shape object from specified dimensions.
      Parameters:
      dims - The dimensions of this shape, i.e. its size along each axis. All dimensions must be non-negative.
      Throws:
      IllegalArgumentException - If any dimension is negative.
  • Method Details

    • getRank

      public int getRank()
      Gets the rank of a tensor with this shape.
      Returns:
      The rank for a tensor with this shape.
    • getDims

      public int[] getDims()
      Gets the shape of a tensor as an array of dimensions.
      Returns:
      Shape of a tensor as an integer array.
    • getStrides

      public int[] getStrides()
      Gets the strides of this shape as an array. Strides are the step sizes needed to move from one element to another along each axis in the tensor.
      Returns:
      The strides of this shape as an integer array.
    • get

      public int get(int i)
      Gets the size of this shape in the specified dimension.
      Parameters:
      i - Dimension to get the size of.
      Returns:
      The size of this shape object in the specified dimension.
    • slice

      public Shape slice(int startIdx)
      Returns a slice of this shape starting from the specified index to the end of this shape's dimensions.
      Parameters:
      startIdx - The starting index for slicing (inclusive).
      Returns:
      A new Shape object containing the dimensions from startIdx to the end dimension.
      Throws:
      IndexOutOfBoundsException - If startIdx is out of bounds of the rank of this shape.
    • slice

      public Shape slice(int startIdx, int stopIdx)
      Returns a slice of this shape from the specified start index to the stop index of this shape's dimensions.
      Parameters:
      startIdx - The starting index for slicing (inclusive).
      stopIdx - The stopping index for slicing (exclusive).
      Returns:
      A new Shape object containing the dimensions from startIdx to stopIdx.
      Throws:
      IndexOutOfBoundsException - If startIdx or stopIdx is out of bounds.
      IllegalArgumentException - If startIdx > stopIdx.
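A minimal sketch of both slice overloads on a plain dimension array, assuming valid bounds of 0 <= startIdx, stopIdx <= rank (the class may treat the edges differently). `SliceSketch` is hypothetical; it relies on the standard `java.util.Arrays.copyOfRange`, whose `to` argument is exclusive, matching the documented stopIdx semantics.

```java
import java.util.Arrays;

/** Hypothetical sketch of slicing a shape's dimension array. */
final class SliceSketch {
    private SliceSketch() {}

    /** Returns dims[startIdx, stopIdx) as a new array (start inclusive, stop exclusive). */
    static int[] slice(int[] dims, int startIdx, int stopIdx) {
        // Assumption: valid bounds are 0 <= startIdx, stopIdx <= rank.
        if (startIdx < 0 || startIdx > dims.length || stopIdx < 0 || stopIdx > dims.length)
            throw new IndexOutOfBoundsException(
                "Slice [" + startIdx + ", " + stopIdx + ") out of bounds for rank " + dims.length);
        if (startIdx > stopIdx)
            throw new IllegalArgumentException("startIdx > stopIdx: " + startIdx + " > " + stopIdx);
        return Arrays.copyOfRange(dims, startIdx, stopIdx);
    }

    /** Slices from startIdx to the end of the dimensions. */
    static int[] slice(int[] dims, int startIdx) {
        return slice(dims, startIdx, dims.length);
    }
}
```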
    • flatten

      public Shape flatten()
      Flattens this shape to a rank-1 shape with dimension equal to the product of all of this shape's dimensions.
      Returns:
      A rank-1 shape with dimension equal to the product of all of this shape's dimensions.
      Throws:
      ArithmeticException - If the product of this shape's dimensions is too large to be stored in a 32-bit integer.
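Flattening reduces to taking the product of the dimensions with overflow detection. This hypothetical `FlattenSketch` uses the standard `Math.multiplyExact`, which throws ArithmeticException on int overflow. It assumes non-negative dimensions and short-circuits on a zero dimension so that a shape such as 65536x65536x0 (zero elements) does not spuriously overflow; it also assumes a scalar (rank-0) shape flattens to {1}, which may differ from the real class.

```java
/** Hypothetical sketch of flattening a shape to rank 1. */
final class FlattenSketch {
    private FlattenSketch() {}

    /** Returns the single dimension of the flattened shape: the product of all dims. */
    static int[] flatten(int... dims) {
        for (int d : dims) {
            if (d == 0) return new int[] {0}; // zero elements: no overflow possible
        }
        int product = 1; // empty product: a scalar shape is assumed to flatten to {1}
        // With all dimensions positive, partial products never decrease, so
        // multiplyExact throws exactly when the full product exceeds Integer.MAX_VALUE.
        for (int d : dims) product = Math.multiplyExact(product, d);
        return new int[] {product};
    }
}
```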
    • get1DIndex

      public int get1DIndex(int... nDIndex)
      Computes the index of the 1D data array for a dense tensor from nD indices for a tensor with this shape.
      Parameters:
      nDIndex - nD index within a tensor with this shape.
      Returns:
      The 1D index of the element at the specified nD index in the 1D data array of a dense tensor.
      Throws:
      IllegalArgumentException - If the number of indices does not match the rank of this shape.
      IndexOutOfBoundsException - If any index does not fit within a tensor with this shape.
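A bounds-checked version of the nD-to-1D mapping, as a sketch. `IndexSketch` is hypothetical; it throws the exception types documented here and accumulates the index in Horner form (flat = flat * dims[i] + nDIndex[i]), which equals the sum of nDIndex[i] * strides[i] for row-major strides.

```java
/** Hypothetical sketch of a bounds-checked nD-to-1D index mapping for a row-major array. */
final class IndexSketch {
    private IndexSketch() {}

    static int get1DIndex(int[] dims, int... nDIndex) {
        if (nDIndex.length != dims.length)
            throw new IllegalArgumentException("Expected " + dims.length + " indices but got " + nDIndex.length);
        int flat = 0;
        for (int i = 0; i < dims.length; i++) {
            if (nDIndex[i] < 0 || nDIndex[i] >= dims[i])
                throw new IndexOutOfBoundsException(
                    "Index " + nDIndex[i] + " out of bounds for axis " + i + " with size " + dims[i]);
            // Horner form: equivalent to summing nDIndex[i] * strides[i] with row-major strides.
            flat = flat * dims[i] + nDIndex[i];
        }
        return flat;
    }
}
```

For dims {3, 4, 5}, get1DIndex(dims, 2, 1, 4) returns 2*20 + 1*5 + 4 = 49.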
    • unsafeGet1DIndex

      public int unsafeGet1DIndex(int... nDIndex)

      Computes the index of the 1D data array for a dense tensor from nD indices for a tensor with this shape.

      Warning: Unlike get1DIndex(int...), this method does not perform bounds checking on indices. If the indices are invalid, it may throw an exception or silently return an incorrect index.

      Parameters:
      nDIndex - Indices of tensor with this shape.
      Returns:
      The index of the element at the specified indices in the 1D data array of a dense tensor.
      Throws:
      IllegalArgumentException - If the number of indices does not match the rank of this shape.
      IndexOutOfBoundsException - If any index does not fit within a tensor with this shape.
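To make the warning concrete: with no bounds checks, an out-of-range index can alias a different, valid element instead of failing. This sketch assumes the unchecked method simply sums index * stride; `UncheckedIndexSketch` is hypothetical.

```java
/** Hypothetical unchecked nD-to-1D index mapping: no rank or bounds validation. */
final class UncheckedIndexSketch {
    private UncheckedIndexSketch() {}

    static int unsafeGet1DIndex(int[] strides, int... nDIndex) {
        int flat = 0;
        for (int i = 0; i < nDIndex.length; i++)
            // An index past its axis size is not detected; passing more indices than
            // there are strides fails only when strides[i] itself is out of range.
            flat += nDIndex[i] * strides[i];
        return flat;
    }
}
```

With strides {4, 1} (a 3x4 shape), the invalid index (0, 5) yields 5, the same flat index as the valid (1, 1), and no exception is raised.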
    • getNdIndices

      public int[] getNdIndices(int index)
      Efficiently computes the nD tensor index based on a 1D index from the internal 1D data array.
      Parameters:
      index - Index of internal 1D data array.
      Returns:
      The multidimensional indices corresponding to the 1D data array index. This will be an array of integers with length equal to the rank of this shape.
    • getNdIndices

      public int[][] getNdIndices(int... indices)
      Efficiently computes the nD tensor indices from multiple 1D indices from the internal 1D data array.
      Parameters:
      indices - Array of 1D indices.
      Returns:
      The multidimensional indices corresponding to each 1D data array index. Row k of the returned array holds the nD index for indices[k] and has length equal to the rank of this shape.
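The inverse mapping peels off one coordinate per axis by integer division and remainder with the row-major strides. Here is a hypothetical sketch covering both overloads (the batched one is named `getNdIndicesBatch` to keep the example unambiguous); it assumes all dimensions are positive, since a zero dimension would produce a zero stride.

```java
/** Hypothetical sketch of converting flat (1D) indices back to nD indices for a row-major shape. */
final class NdIndexSketch {
    private NdIndexSketch() {}

    /** Peels off one coordinate per axis by dividing by that axis's stride. */
    static int[] getNdIndices(int[] strides, int index) {
        int[] nd = new int[strides.length];
        int remainder = index;
        for (int i = 0; i < strides.length; i++) {
            nd[i] = remainder / strides[i];
            remainder %= strides[i];
        }
        return nd;
    }

    /** Batched version: row k holds the nD index for indices[k]. */
    static int[][] getNdIndicesBatch(int[] strides, int... indices) {
        int[][] result = new int[indices.length][];
        for (int k = 0; k < indices.length; k++)
            result[k] = getNdIndices(strides, indices[k]);
        return result;
    }
}
```

With strides {20, 5, 1}, getNdIndices(strides, 56) returns {2, 3, 1}, matching the class-level example.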
    • swapAxes

      public Shape swapAxes(int axis1, int axis2)
      Swaps two axes of this shape. If this shape has had its strides computed, then new strides will also be computed for the resulting shape.
      Parameters:
      axis1 - First axis to swap.
      axis2 - Second axis to swap.
      Returns:
      A copy of this shape with the specified axes swapped.
      Throws:
      ArrayIndexOutOfBoundsException - If either axis is not within [0, rank-1].
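A sketch of swapping two axes on a copy of the dimension array, so the original stays unchanged as the immutability contract requires. `SwapSketch` is hypothetical; indexing the array with an out-of-range axis naturally raises the ArrayIndexOutOfBoundsException documented here.

```java
/** Hypothetical sketch of swapping two axes; returns a new array and leaves the input untouched. */
final class SwapSketch {
    private SwapSketch() {}

    static int[] swapAxes(int[] dims, int axis1, int axis2) {
        int[] swapped = dims.clone();
        // An axis outside [0, rank-1] throws ArrayIndexOutOfBoundsException on array access.
        swapped[axis1] = dims[axis2];
        swapped[axis2] = dims[axis1];
        return swapped;
    }
}
```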
    • permuteAxes

      public Shape permuteAxes(int... axes)
      Permutes the axes of this shape.
      Parameters:
      axes - New axes permutation for the shape. This must be a permutation of {0, 1, 2, ... N-1} where N is the rank of this shape.
      Returns:
      A new shape with its axes permuted according to axes.
      Throws:
      ArrayIndexOutOfBoundsException - If axes is not a permutation of {0, 1, 2, ... N-1}.
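A validated permutation can be sketched with a `seen` array that rejects repeated or out-of-range axes. This hypothetical `PermuteSketch` assumes zero-based axes and the convention that new axis i takes the size of old axis axes[i] (as in NumPy's transpose); it reports malformed input with IllegalArgumentException, which is a choice of the sketch rather than the documented exception type.

```java
/** Hypothetical sketch of a validated axis permutation using zero-based axes. */
final class PermuteSketch {
    private PermuteSketch() {}

    /** New axis i takes the size of old axis axes[i]. */
    static int[] permuteAxes(int[] dims, int... axes) {
        if (axes.length != dims.length)
            throw new IllegalArgumentException("Expected " + dims.length + " axes but got " + axes.length);
        boolean[] seen = new boolean[dims.length];
        int[] permuted = new int[dims.length];
        for (int i = 0; i < axes.length; i++) {
            int axis = axes[i];
            // Reject out-of-range or repeated axes so the result is a true permutation.
            if (axis < 0 || axis >= dims.length || seen[axis])
                throw new IllegalArgumentException("axes must be a permutation of {0, ..., " + (dims.length - 1) + "}");
            seen[axis] = true;
            permuted[i] = dims[axis];
        }
        return permuted;
    }
}
```

For dims {3, 4, 5} and axes {2, 0, 1}, the result is {5, 3, 4}.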
    • unsafePermuteAxes

      public Shape unsafePermuteAxes(int... axes)

      Permutes the axes of this shape.

      Warning: Unlike permuteAxes(int...), this method does not perform bounds checking on axes or ensure that axes is a permutation of {0, 1, 2, ... n-1}. This may result in unexpected behavior if axes is malformed.

      Parameters:
      axes - New axes permutation for the shape. This must be a permutation of {0, 1, 2, ... n-1} where n is the rank of this shape.
      Returns:
      A new shape with its axes permuted according to axes.
    • totalEntries

      public BigInteger totalEntries()
      Gets the total number of elements for a tensor with this shape.
      Returns:
      The total number of elements for a tensor with this shape.
    • totalEntriesIntValueExact

      public int totalEntriesIntValueExact()

      Gets the total number of elements for a tensor with this shape. If the total number of elements exceeds Integer.MAX_VALUE, an exception is thrown.

      This method is likely to be more efficient than totalEntries() if a primitive int value is desired.

      Returns:
      The total number of elements for a tensor with this shape.
      Throws:
      ArithmeticException - If the total number of elements overflows a primitive int.
    • totalEntriesLongValueExact

      public long totalEntriesLongValueExact()

      Gets the total number of elements for a tensor with this shape as a long. If the total number of elements exceeds Long.MAX_VALUE, an exception is thrown.

      Returns:
      The total number of elements for a tensor with this shape.
      Throws:
      ArithmeticException - If the total number of elements overflows a primitive long.
    • isIntSized

      public boolean isIntSized()
      Checks if the total number of elements represented by this shape can be represented as a 32-bit integer without overflowing.
      Returns:
      true if the total number of elements represented by this shape can be represented as a 32-bit integer without overflowing; false if it would overflow.
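The counting methods differ mainly in how they guard against overflow. Here is a hypothetical sketch, `EntryCountSketch`, using the standard `BigInteger` and `Math.multiplyExact`; it assumes non-negative dimensions and treats the empty product (a scalar shape) as one element. The int variant short-circuits on a zero dimension because, with all dimensions positive, partial products never decrease, so an intermediate overflow guarantees the final product overflows too.

```java
import java.math.BigInteger;

/** Hypothetical sketch of overflow-safe element counting for a shape. */
final class EntryCountSketch {
    private EntryCountSketch() {}

    /** Exact product of all dimensions; BigInteger cannot overflow. */
    static BigInteger totalEntries(int... dims) {
        BigInteger total = BigInteger.ONE; // empty product: a scalar holds one element
        for (int d : dims) total = total.multiply(BigInteger.valueOf(d));
        return total;
    }

    /** Primitive-int count; throws ArithmeticException if the product does not fit in an int. */
    static int totalEntriesIntValueExact(int... dims) {
        for (int d : dims) {
            if (d == 0) return 0; // any zero dimension means zero elements
        }
        int total = 1;
        for (int d : dims) total = Math.multiplyExact(total, d);
        return total;
    }

    /** Long count; longValueExact throws ArithmeticException if the product does not fit in a long. */
    static long totalEntriesLongValueExact(int... dims) {
        return totalEntries(dims).longValueExact();
    }

    /** True if the element count fits in a 32-bit signed int. */
    static boolean isIntSized(int... dims) {
        return totalEntries(dims).compareTo(BigInteger.valueOf(Integer.MAX_VALUE)) <= 0;
    }
}
```

For example, a 46341x46341 shape has 2,147,488,281 elements, just over Integer.MAX_VALUE (2,147,483,647), so isIntSized returns false while totalEntriesLongValueExact still succeeds.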
    • isSquare

      public boolean isSquare()
      Checks if this shape is square. That is, if all dimensions of this shape are equal.
      Returns:
      true if all dimensions of this shape are equal; false otherwise.
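A one-pass sketch of the squareness check as defined here (all dimensions equal). `SquareSketch` is hypothetical, and it treats rank-0 and rank-1 shapes as vacuously square; the real class may handle those edge cases differently.

```java
/** Hypothetical sketch of the "all dimensions equal" squareness check. */
final class SquareSketch {
    private SquareSketch() {}

    static boolean isSquare(int... dims) {
        for (int i = 1; i < dims.length; i++) {
            if (dims[i] != dims[0]) return false; // any mismatch with the first dimension
        }
        return true; // vacuously true for rank 0 and rank 1 under this definition
    }
}
```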
    • equals

      public boolean equals(Object b)
      Checks if an object is equal to this shape.
      Overrides:
      equals in class Object
      Parameters:
      b - Object to compare with this shape.
      Returns:
      true if b is a Shape object and equal to this shape; false otherwise.
    • hashCode

      public int hashCode()
      Generates the hashcode for this shape object. This is computed by passing the dims array of this shape object to Arrays.hashCode(int[]).
      Overrides:
      hashCode in class Object
      Returns:
      The hashcode for this shape object.
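Because hashCode is documented as Arrays.hashCode(int[]) of the dims array, a consistent equals would compare the dims element-wise with Arrays.equals. Here is a minimal sketch under that assumption; `DimsKey` is a hypothetical value class, not flag4j's Shape.

```java
import java.util.Arrays;

/** Hypothetical value class showing equals/hashCode consistency based on a dims array. */
final class DimsKey {
    private final int[] dims;

    DimsKey(int... dims) { this.dims = dims.clone(); }

    @Override
    public boolean equals(Object b) {
        if (this == b) return true;
        if (!(b instanceof DimsKey)) return false;
        return Arrays.equals(dims, ((DimsKey) b).dims); // element-wise, not reference, comparison
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(dims); // equal dims arrays yield equal hash codes
    }
}
```

Consistency matters when shapes are used as keys in a HashMap or HashSet: two shapes with equal dimensions must hash to the same bucket and compare equal.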
    • toString

      public String toString()
      Converts this Shape object to a string format.
      Overrides:
      toString in class Object
      Returns:
      The string representation for this Shape object.