# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for car beam pipelines."""
import apache_beam as beam
[docs]def BeamInit():
"""Initialize the beam program.
Typically first thing to run in main(). This call is needed before FLAGS
are accessed, for example.
"""
pass
[docs]def GetPipelineRoot(options=None):
"""Return the root of the beam pipeline.
Typical usage looks like:
with GetPipelineRoot() as root:
_ = (root | beam.ParDo() | ...)
In this example, the pipeline is automatically executed when the context is
exited, though one can manually run the pipeline built from the root object as
well.
Args:
options: A beam.options.pipeline_options.PipelineOptions object.
Returns:
A beam.Pipeline root object.
"""
return beam.Pipeline(options=options)
[docs]def GetReader(record_format, file_pattern, value_coder, **kwargs):
"""Returns a beam Reader based on record_format and file_pattern.
Args:
record_format: String record format, e.g., 'tfrecord'.
file_pattern: String path describing files to be read.
value_coder: Coder to use for the values of each record.
**kwargs: arguments to pass to the corresponding Reader object constructor.
Returns:
A beam reader object.
Raises:
ValueError: If an unsupported record_format is provided.
"""
if record_format == "tfrecord":
return beam.io.ReadFromTFRecord(file_pattern, coder=value_coder, **kwargs)
raise ValueError("Unsupported record format: {}".format(record_format))
[docs]def GetWriter(record_format, file_pattern, value_coder, **kwargs):
"""Returns a beam Writer.
Args:
record_format: String record format, e.g., 'tfrecord' to write as.
file_pattern: String path describing files to be written to.
value_coder: Coder to use for the values of each written record.
**kwargs: arguments to pass to the corresponding Writer object constructor.
Returns:
A beam writer object.
Raises:
ValueError: If an unsupported record_format is provided.
"""
if record_format == "tfrecord":
return beam.io.WriteToTFRecord(file_pattern, coder=value_coder, **kwargs)
raise ValueError("Unsupported record format: {}".format(record_format))
[docs]def GetEmitterFn(record_format):
"""Returns an Emitter function for the given record_format.
An Emitter function takes in a key and value as arguments and returns
a structure that is compatible with the Beam Writer associated with
the corresponding record_format.
Args:
record_format: String record format, e.g., 'tfrecord' to write as.
Returns:
An emitter function of (key, value) -> Writer's input type.
Raises:
ValueError: If an unsupported record_format is provided.
"""
def _ValueEmitter(key, value):
del key
return [value]
if record_format == "tfrecord":
return _ValueEmitter
raise ValueError("Unsupported record format: {}".format(record_format))