Class AllToAll<T extends TType>
java.lang.Object
org.tensorflow.op.RawOp
org.tensorflow.op.tpu.AllToAll<T>
@Operator(group="tpu")
public final class AllToAll<T extends TType>
extends RawOp
implements Operand<T>
An Op to exchange data across TPU replicas.
On each replica, the input is split into split_count blocks along
split_dimension, and the blocks are sent to the other replicas in the same subgroup, as given by group_assignment. After
receiving split_count - 1 blocks from the other replicas, each replica concatenates them, together with the block it kept for itself,
along concat_dimension to form the output.
For example, suppose there are 2 TPU replicas:
replica 0 receives input: [[A, B]]
replica 1 receives input: [[C, D]]
group_assignment=[[0, 1]]
concat_dimension=0
split_dimension=1
split_count=2
replica 0's output: [[A], [C]]
replica 1's output: [[B], [D]]
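To make the data movement concrete, here is a minimal plain-Java simulation of these semantics on rank-2 arrays, run on the two-replica example above. It does not call TensorFlow. The class AllToAllSimulation and its helpers are hypothetical. The routing rule is an assumption inferred from the example rather than taken from the op's implementation: block j of every sender goes to the j-th replica listed in its subgroup, and each receiver concatenates blocks in sender order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical simulation of AllToAll on rank-2 String arrays (no TensorFlow involved).
final class AllToAllSimulation {

    // Split a rank-2 array into `count` equal blocks along `dim` (0 = rows, 1 = columns).
    static String[][][] split(String[][] x, int dim, int count) {
        int rows = x.length, cols = x[0].length;
        int size = (dim == 0) ? rows : cols;
        if (size % count != 0) {
            throw new IllegalArgumentException("size " + size + " not divisible by " + count);
        }
        int step = size / count;
        String[][][] blocks = new String[count][][];
        for (int b = 0; b < count; b++) {
            if (dim == 0) {
                blocks[b] = Arrays.copyOfRange(x, b * step, (b + 1) * step);
            } else {
                String[][] block = new String[rows][];
                for (int r = 0; r < rows; r++) {
                    block[r] = Arrays.copyOfRange(x[r], b * step, (b + 1) * step);
                }
                blocks[b] = block;
            }
        }
        return blocks;
    }

    // Concatenate rank-2 blocks along `dim` (0 = stack rows, 1 = append columns).
    static String[][] concat(List<String[][]> blocks, int dim) {
        if (dim == 0) {
            List<String[]> rows = new ArrayList<>();
            for (String[][] b : blocks) rows.addAll(Arrays.asList(b));
            return rows.toArray(new String[0][]);
        }
        int rows = blocks.get(0).length;
        String[][] out = new String[rows][];
        for (int r = 0; r < rows; r++) {
            List<String> row = new ArrayList<>();
            for (String[][] b : blocks) row.addAll(Arrays.asList(b[r]));
            out[r] = row.toArray(new String[0]);
        }
        return out;
    }

    // inputs[r] is replica r's input; each row of groupAssignment lists the replica ids of one subgroup.
    static String[][][] allToAll(String[][][] inputs, int[][] groupAssignment,
                                 int concatDim, int splitDim, int splitCount) {
        String[][][] outputs = new String[inputs.length][][];
        for (int[] group : groupAssignment) {
            if (group.length != splitCount) {
                throw new IllegalArgumentException("splitCount must equal the subgroup size");
            }
            // blocks[i][j] is the j-th block produced by the i-th replica of this subgroup.
            String[][][][] blocks = new String[group.length][][][];
            for (int i = 0; i < group.length; i++) {
                blocks[i] = split(inputs[group[i]], splitDim, splitCount);
            }
            // Replica group[j] collects block j from every sender, in sender order.
            for (int j = 0; j < group.length; j++) {
                List<String[][]> received = new ArrayList<>();
                for (int i = 0; i < group.length; i++) received.add(blocks[i][j]);
                outputs[group[j]] = concat(received, concatDim);
            }
        }
        return outputs;
    }

    public static void main(String[] args) {
        String[][][] inputs = { {{"A", "B"}}, {{"C", "D"}} };
        String[][][] out = allToAll(inputs, new int[][] {{0, 1}}, 0, 1, 2);
        System.out.println(Arrays.deepToString(out[0])); // [[A], [C]]
        System.out.println(Arrays.deepToString(out[1])); // [[B], [D]]
    }
}
```

Running main reproduces the two outputs listed above. Swapping the attributes (split_dimension=0, concat_dimension=1) on column-shaped inputs gives row-shaped outputs instead.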
Field Summary
Modifier and Type | Field | Description
static final String | OP_NAME | The name of this op, as known by the TensorFlow core engine.
Method Summary
Modifier and Type | Method | Description
Output<T> | asOutput() | Returns the symbolic handle of the tensor.
static <T extends TType> AllToAll<T> | create(Scope scope, Operand<T> input, Operand<TInt32> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount) | Factory method to create a class wrapping a new AllToAll operation.
Output<T> | output() | Gets output.
Field Details
OP_NAME
public static final String OP_NAME
The name of this op, as known by the TensorFlow core engine.
Constructor Details
AllToAll
Method Details
create
@Endpoint(describeByClass=true)
public static <T extends TType> AllToAll<T> create(Scope scope, Operand<T> input, Operand<TInt32> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount)
Factory method to create a class wrapping a new AllToAll operation.
Type Parameters:
T - data type for AllToAll output and operands
Parameters:
scope - current scope
input - The local input to exchange across replicas.
groupAssignment - An int32 tensor with shape [num_groups, num_replicas_per_group]. group_assignment[i] represents the replica IDs in the i-th subgroup.
concatDimension - The dimension along which the received blocks are concatenated.
splitDimension - The dimension along which the input is split.
splitCount - The number of splits. This must equal the subgroup size (group_assignment.get_shape()[1]).
Returns:
a new instance of AllToAll
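Judging from the two-replica example in the class description, the parameters determine the output shape: the output holds 1/splitCount of the input along splitDimension and splitCount times the input along concatDimension. The sketch below checks the documented constraint (splitCount equals the subgroup size) plus an assumed one (the size of the split dimension must be divisible by splitCount) and then computes the per-replica output shape. AllToAllShapes and outputShape are hypothetical names, and plain int/long values stand in for the Long attributes and the TInt32 groupAssignment operand.

```java
import java.util.Arrays;

// Hypothetical shape check for AllToAll; not part of the TensorFlow Java API.
final class AllToAllShapes {

    static long[] outputShape(long[] inputShape, int[][] groupAssignment,
                              int concatDimension, int splitDimension, long splitCount) {
        int rank = inputShape.length;
        if (concatDimension < 0 || concatDimension >= rank
                || splitDimension < 0 || splitDimension >= rank) {
            throw new IllegalArgumentException("dimension out of range for rank " + rank);
        }
        // Documented constraint: splitCount equals the subgroup size, group_assignment.get_shape()[1].
        for (int[] group : groupAssignment) {
            if (group.length != splitCount) {
                throw new IllegalArgumentException("splitCount must equal subgroup size " + group.length);
            }
        }
        // Assumed constraint: the split dimension divides evenly into splitCount blocks.
        if (inputShape[splitDimension] % splitCount != 0) {
            throw new IllegalArgumentException("split dimension is not divisible by splitCount");
        }
        long[] out = inputShape.clone();
        // Each block keeps 1/splitCount of the split dimension...
        out[splitDimension] /= splitCount;
        // ...and splitCount received blocks are stacked along the concat dimension.
        out[concatDimension] *= splitCount;
        return out;
    }

    public static void main(String[] args) {
        // Shapes from the class-level example: input [1, 2] becomes output [2, 1].
        long[] out = outputShape(new long[] {1, 2}, new int[][] {{0, 1}}, 0, 1, 2);
        System.out.println(Arrays.toString(out)); // [2, 1]
    }
}
```

When concatDimension equals splitDimension, the two adjustments cancel and the output shape matches the input shape.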
output
public Output<T> output()
Gets output.
asOutput
Description copied from interface: Operand
Returns the symbolic handle of the tensor. Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is used to obtain a symbolic handle that represents the computation of the input.