Class AllToAll<T extends TType>

java.lang.Object
org.tensorflow.op.RawOp
org.tensorflow.op.tpu.AllToAll<T>
All Implemented Interfaces:
Shaped, Op, Operand<T>

@Operator(group="tpu") public final class AllToAll<T extends TType> extends RawOp implements Operand<T>
An Op to exchange data across TPU replicas. On each replica, the input is split into split_count blocks along split_dimension and sent to the other replicas according to group_assignment. After receiving split_count - 1 blocks from the other replicas, each replica concatenates the blocks along concat_dimension to form its output.

For example, suppose there are 2 TPU replicas:

replica 0 receives input: [[A, B]]
replica 1 receives input: [[C, D]]

group_assignment=[[0, 1]]
concat_dimension=0
split_dimension=1
split_count=2

replica 0's output: [[A], [C]]
replica 1's output: [[B], [D]]
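The exchange in the example above can be traced with a small plain-Java simulation. This is only a sketch of the described semantics, not the TPU implementation: the class `AllToAllSketch` and its methods are hypothetical names, tensors are modeled as 2-D `String` arrays, only a single group is handled, and it assumes that block j goes to the j-th replica listed in the group and that received blocks are concatenated in group order (which is consistent with the example).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper class; it does not use or mirror any TensorFlow API.
final class AllToAllSketch {

  // Returns block `index` out of `count` equal blocks of `m` along `dim` (0 = rows, 1 = columns).
  static String[][] block(String[][] m, int dim, int count, int index) {
    int size = (dim == 0) ? m.length : m[0].length;
    if (size % count != 0) {
      throw new IllegalArgumentException("dimension size must be divisible by split count");
    }
    int len = size / count;
    int start = index * len;
    if (dim == 0) {
      return Arrays.copyOfRange(m, start, start + len);
    }
    String[][] out = new String[m.length][];
    for (int r = 0; r < m.length; r++) {
      out[r] = Arrays.copyOfRange(m[r], start, start + len);
    }
    return out;
  }

  // Concatenates the parts along `dim` (0 = stack rows, 1 = append columns).
  static String[][] concat(List<String[][]> parts, int dim) {
    if (dim == 0) {
      List<String[]> rows = new ArrayList<>();
      for (String[][] p : parts) {
        rows.addAll(Arrays.asList(p));
      }
      return rows.toArray(new String[0][]);
    }
    int rowCount = parts.get(0).length;
    String[][] out = new String[rowCount][];
    for (int r = 0; r < rowCount; r++) {
      List<String> row = new ArrayList<>();
      for (String[][] p : parts) {
        row.addAll(Arrays.asList(p[r]));
      }
      out[r] = row.toArray(new String[0]);
    }
    return out;
  }

  // Simulates the exchange for one group; inputs[i] is the input of replica i.
  static String[][][] allToAll(
      String[][][] inputs, int[] group, int concatDim, int splitDim, int splitCount) {
    if (splitCount != group.length) {
      throw new IllegalArgumentException("splitCount must equal the group size");
    }
    String[][][] outputs = new String[inputs.length][][];
    for (int pos = 0; pos < group.length; pos++) {
      // The replica at position `pos` collects block `pos` from every group member.
      List<String[][]> received = new ArrayList<>();
      for (int sender : group) {
        received.add(block(inputs[sender], splitDim, splitCount, pos));
      }
      outputs[group[pos]] = concat(received, concatDim);
    }
    return outputs;
  }

  public static void main(String[] args) {
    String[][][] inputs = {{{"A", "B"}}, {{"C", "D"}}};
    String[][][] out = allToAll(inputs, new int[] {0, 1}, 0, 1, 2);
    System.out.println(Arrays.deepToString(out[0])); // [[A], [C]]
    System.out.println(Arrays.deepToString(out[1])); // [[B], [D]]
  }
}
```

With concat_dimension=0 and split_dimension=1, each replica keeps one column of its own input and stacks it on top of the matching column from the other replica, which reproduces the outputs listed above.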

  • Field Details

  • Constructor Details

    • AllToAll

      public AllToAll(Operation operation)
  • Method Details

    • create

      @Endpoint(describeByClass=true) public static <T extends TType> AllToAll<T> create(Scope scope, Operand<T> input, Operand<TInt32> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount)
      Factory method to create a class wrapping a new AllToAll operation.
      Type Parameters:
      T - data type for AllToAll output and operands
      Parameters:
      scope - current scope
      input - The local input to the exchange.
      groupAssignment - An int32 tensor with shape [num_groups, num_replicas_per_group]. group_assignment[i] represents the replica ids in the ith subgroup.
      concatDimension - The dimension number to concatenate.
      splitDimension - The dimension number to split.
      splitCount - The number of splits. This number must equal the sub-group size (group_assignment.get_shape()[1]).
      Returns:
      a new instance of AllToAll
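The constraint between splitCount and groupAssignment described above can be checked on the host before building the op. The following is a minimal sketch in plain Java; `AllToAllArgs` and `check` are hypothetical names that are not part of the TensorFlow Java API, and the group assignment is assumed to be available as an `int[][]` rather than as a tensor.

```java
// Hypothetical pre-flight check; not part of org.tensorflow.
final class AllToAllArgs {

  // Verifies that groupAssignment is rectangular, i.e. shaped
  // [num_groups, num_replicas_per_group], and that splitCount equals
  // num_replicas_per_group.
  static void check(int[][] groupAssignment, long splitCount) {
    if (groupAssignment.length == 0) {
      throw new IllegalArgumentException("groupAssignment must contain at least one group");
    }
    int groupSize = groupAssignment[0].length;
    for (int[] group : groupAssignment) {
      if (group.length != groupSize) {
        throw new IllegalArgumentException("all groups must have the same size");
      }
    }
    if (splitCount != groupSize) {
      throw new IllegalArgumentException(
          "splitCount (" + splitCount + ") must equal the sub-group size (" + groupSize + ")");
    }
  }

  public static void main(String[] args) {
    check(new int[][] {{0, 1}}, 2L); // passes: one group of two replicas, two splits
    System.out.println("arguments are consistent");
  }
}
```

Running such a check before calling `create` surfaces a mismatched splitCount as an ordinary Java exception instead of a failure later on.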
    • output

      public Output<T> output()
      Gets output. The exchanged result.
      Returns:
      output.
    • asOutput

      public Output<T> asOutput()
      Description copied from interface: Operand
      Returns the symbolic handle of the tensor.

      Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is used to obtain a symbolic handle that represents the computation of the input.

      Specified by:
      asOutput in interface Operand<T extends TType>
      See Also: