# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Send/Recv ops.
The following _Send()/_Recv() are adapted from python op wrappers
generated by python_op_gen_main. python_op_gen_main.cc's
PrintAllPythonOps needs to be updated to export internal ops.
"""
from lingvo import compat as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla
# pylint: enable=g-direct-tensorflow-import
[docs]def _TpuCore(device):
"""Returns the TPU core represented by <device>, or -1 if not TPU."""
prefix = "device:TPU_REPLICATED_CORE:"
if prefix in device:
return int(device[len(prefix):])
return -1
class Channel:
  """A communication channel to transfer tensors in order.

  A channel pairs exactly one Send() with one Recv() between two
  fully-specified devices. Between TPU_REPLICATED_CORE devices it lowers
  to XLA send/recv (cross-core only, fully defined shapes required);
  otherwise it uses the TF Send/Recv rendezvous ops.
  """

  def __init__(self, dtype, shape, send_device, recv_device, name=None):
    """Construct a channel.

    Args:
      dtype: The dtype of tensors sent through the channel.
      shape: The shape of tensors sent through the channel. Must be a fully
        defined shape for TPUs.
      send_device: A fully-specified tensorflow device.
      recv_device: A fully-specified tensorflow device.
      name: A name for the channel (optional).
    """
    current_graph = tf.get_default_graph()
    assert current_graph, "A channel is scoped within a tf.Graph"
    self._dtype = dtype
    self._send_device = send_device
    self._recv_device = recv_device
    # The Send/Recv pair rendezvous on tensor_name, so the name must be
    # unique within the graph.
    self._name = current_graph.unique_name(name if name else "channel")

    assert shape is not None
    shape = tf.TensorShape(shape)
    self._shape = shape
    self._send_tpu_core = _TpuCore(send_device)
    self._recv_tpu_core = _TpuCore(recv_device)
    # A channel transfers exactly one tensor; track both endpoints so a
    # second Send() or Recv() fails loudly instead of silently creating a
    # duplicate op on the same rendezvous key.
    self._send_called = False
    self._recv_called = False
    assert ((self._send_tpu_core == -1) == (self._recv_tpu_core == -1)), (
        "Mixing TPU and non-TPU: %s and %s" % (send_device, recv_device))
    if self._send_tpu_core >= 0:
      assert self._shape.is_fully_defined(), (
          "TPU channel must have fully defined shape. Name: %s, shape: %s" %
          (self._name, self._shape))
      assert self._send_tpu_core != self._recv_tpu_core, (
          "TPU send/recv must be cross-core: %s and %s" %
          (send_device, recv_device))

  def Send(self, tensor):
    """Sends a tensor through the channel.

    Args:
      tensor: The tensor to send; its dtype must match the channel's dtype.

    Returns:
      The Send op (non-TPU) or the XLA send op (TPU).
    """
    assert tensor.dtype == self._dtype
    assert not self._send_called, ("Send called multiple times for %s" %
                                   self._name)
    self._send_called = True
    if self._send_tpu_core == -1:
      return tf.raw_ops.Send(
          tensor=tensor,
          tensor_name=self._name,
          send_device=self._send_device,
          send_device_incarnation=0,
          recv_device=self._recv_device)
    else:
      with tf.device(self._send_device):
        return xla.send(
            tensor, tensor_name=self._name, name="Send_" + self._name)

  def Recv(self):
    """Receives a tensor from the channel.

    Returns:
      The received tensor, with the channel's shape set on it.
    """
    # Mirror the guard in Send(): each channel carries a single tensor.
    assert not self._recv_called, ("Recv called multiple times for %s" %
                                   self._name)
    self._recv_called = True
    if self._send_tpu_core == -1:
      received = tf.raw_ops.Recv(
          tensor_type=self._dtype,
          tensor_name=self._name,
          send_device=self._send_device,
          send_device_incarnation=0,
          recv_device=self._recv_device)
      # Recv's output shape is unknown at graph-construction time; restore
      # the statically known channel shape for downstream shape inference.
      received.set_shape(self._shape)
      return received
    else:
      with tf.device(self._recv_device):
        return xla.recv(
            self._dtype,
            tensor_name=self._name,
            shape=self._shape,
            name="Recv_" + self._name)