Source code for lingvo.core.sendrecv

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Send/Recv ops.

The following _Send()/_Recv() are adapted from python op wrappers
generated by python_op_gen_main. python_op_gen_main.cc's
PrintAllPythonOps needs to be updated to export internal ops.
"""

from lingvo import compat as tf

# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla
# pylint: enable=g-direct-tensorflow-import


def _TpuCore(device):
  """Returns the TPU core represented by <device>, or -1 if not TPU.

  The core index is read from the text that follows the first occurrence of
  "device:TPU_REPLICATED_CORE:" in the device string, so the marker may be
  preceded by other text (for example a leading "/" or a job/task spec).

  Args:
    device: A device string.

  Returns:
    The integer core index, or -1 when the marker does not occur in <device>.

  Raises:
    ValueError: If the text after the marker is not a valid integer.
  """
  prefix = "device:TPU_REPLICATED_CORE:"
  # Parse relative to where the marker was actually found. Slicing from
  # len(prefix) is only correct when the string starts with the marker, while
  # the membership test accepts it anywhere in the string.
  _, marker, core = device.partition(prefix)
  if not marker:
    return -1
  return int(core)
class Channel:
  """Carries tensors from a sending device to a receiving device, in order.

  When neither endpoint is a TPU core the channel is built from
  tf.raw_ops.Send / tf.raw_ops.Recv. When both endpoints are TPU cores it is
  built from xla.send / xla.recv, each placed under the tf.device scope of its
  own endpoint.
  """

  def __init__(self, dtype, shape, send_device, recv_device, name=None):
    """Creates a channel inside the current default graph.

    Args:
      dtype: Dtype of the tensors that travel through the channel.
      shape: Shape of the tensors that travel through the channel. Must not be
        None, and must be fully defined when the endpoints are TPU cores.
      send_device: A fully-specified tensorflow device that sends.
      recv_device: A fully-specified tensorflow device that receives.
      name: Optional channel name. "channel" is used when it is None or empty;
        either way the name is made unique within the graph.
    """
    graph = tf.get_default_graph()
    assert graph, "A channel is scoped within a tf.Graph"

    self._dtype = dtype
    self._send_device = send_device
    self._recv_device = recv_device
    self._name = graph.unique_name(name or "channel")

    assert shape is not None
    self._shape = tf.TensorShape(shape)

    self._send_tpu_core = _TpuCore(send_device)
    self._recv_tpu_core = _TpuCore(recv_device)
    self._send_called = False
    self._recv_op = None

    # Either both endpoints are TPU cores or neither is.
    sender_off_tpu = self._send_tpu_core == -1
    receiver_off_tpu = self._recv_tpu_core == -1
    assert sender_off_tpu == receiver_off_tpu, (
        "Mixing TPU and non-TPU: %s and %s" % (send_device, recv_device))

    if self._send_tpu_core < 0:
      # The remaining checks only apply to TPU channels.
      return

    assert self._shape.is_fully_defined(), (
        "TPU channel must have fully defined shape. Name: %s, shape: %s" %
        (self._name, self._shape))
    assert self._send_tpu_core != self._recv_tpu_core, (
        "TPU send/recv must be cross-core: %s and %s" %
        (send_device, recv_device))

  def Send(self, tensor):
    """Sends a tensor through the channel; may be called only once.

    Args:
      tensor: The tensor to send. Its dtype must equal the channel's dtype.

    Returns:
      The result of xla.send for a TPU channel, otherwise the result of
      tf.raw_ops.Send.
    """
    assert tensor.dtype == self._dtype
    assert not self._send_called, ("Send called multiple times for %s" %
                                   self._name)
    self._send_called = True

    if self._send_tpu_core != -1:
      # TPU channel: the send op is placed on the sending core.
      with tf.device(self._send_device):
        return xla.send(
            tensor, tensor_name=self._name, name="Send_" + self._name)

    return tf.raw_ops.Send(
        tensor=tensor,
        tensor_name=self._name,
        send_device=self._send_device,
        send_device_incarnation=0,
        recv_device=self._recv_device)

  def Recv(self):
    """Receives a tensor from the channel.

    Returns:
      The result of xla.recv for a TPU channel; otherwise the output of
      tf.raw_ops.Recv with its static shape set to the channel's shape.
    """
    if self._send_tpu_core != -1:
      # TPU channel: the recv op is placed on the receiving core.
      with tf.device(self._recv_device):
        return xla.recv(
            self._dtype,
            tensor_name=self._name,
            shape=self._shape,
            name="Recv_" + self._name)

    tensor = tf.raw_ops.Recv(
        tensor_type=self._dtype,
        tensor_name=self._name,
        send_device=self._send_device,
        send_device_incarnation=0,
        recv_device=self._recv_device)
    tensor.set_shape(self._shape)
    return tensor