# Source code for lingvo.datasets

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for dataset information."""

import ast
import inspect
import os
from typing import Any, List
from absl import logging

# Public member names that are never treated as dataset functions by
# GetDatasets / GetDatasetsAst, even though they are public methods.
NON_DATASET_MEMBERS = [
    'GetAllDatasetParams',
    'GetDatasetParams',
    'Model',
    'Search',
    'Task',
    'ProgramSchedule',
    'UpdateParamsFromSpec',
    'CreateDynamicDatasetMethods',
]


class DatasetFunctionError(TypeError):
  """A public method cannot be used as a dataset function.

  GetDatasets raises this (instead of logging a warning) when it is called
  with warn_on_error=False and finds a public method that has required
  arguments, i.e. one that cannot be called with no arguments.
  """
class GetAllDatasetParamsNotImplementedError(NotImplementedError):
  """A model does not provide GetAllDatasetParams().

  GetDatasets catches this error and falls back to treating the model's public
  methods as dataset functions.
  """
  # NOTE(review): presumably raised by a base-class default implementation of
  # GetAllDatasetParams (not shown in this file) -- confirm.
def GetDatasets(cls: Any, warn_on_error: bool = True) -> List[str]:
  """Returns the list of dataset functions (e.g., Train, Dev, ...).

  First, GetAllDatasetParams() is called on the model. The model is `cls`
  itself when it is an instance, or `cls()` when it is a class that can be
  constructed without arguments. Unless that call raises
  GetAllDatasetParamsNotImplementedError, the sorted keys of its result are
  returned.

  Otherwise, every public routine of `cls` whose name is not in
  `NON_DATASET_MEMBERS` is treated as a dataset. Dataset functions should not
  have any required arguments (besides the implicit `self`/`cls`). `*args`
  and `**kwargs` never need to be supplied, so they are allowed.

  Args:
    cls: A class variable or instance variable. This function expects to be
      called on classes that can be used as model tasks e.g. via
      model_registry.RegisterSingleTaskModel.
    warn_on_error: When a class contains public methods that cannot be used as
      a dataset, if True, logs a warning, if False, raises a
      DatasetFunctionError.

  Returns:
    A list of strings containing names of valid dataset functions for cls.

  Raises:
    DatasetFunctionError: if the cls contains public methods that cannot be
      used as datasets, and warn_on_error is False.
  """
  if inspect.isclass(cls):
    try:
      mdl_params = cls()
    except TypeError:
      # Capture cls construction error (e.g. required __init__ arguments) and
      # fall back to inspecting the class's methods directly.
      mdl_params = None
  else:
    mdl_params = cls

  # Compare against None instead of relying on truthiness: an instance that
  # defines __bool__ or __len__ must not silently skip GetAllDatasetParams().
  if mdl_params is not None:
    try:
      all_datasets = mdl_params.GetAllDatasetParams()
    except GetAllDatasetParamsNotImplementedError:
      pass
    else:
      return sorted(all_datasets.keys())

  class_name = cls.__name__ if inspect.isclass(cls) else cls.__class__.__name__
  # *args / **kwargs have an `empty` default but never need to be supplied,
  # so they do not count as required arguments.
  variadic_kinds = (inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD)

  datasets = []
  for name, _ in inspect.getmembers(cls, inspect.isroutine):
    if name in NON_DATASET_MEMBERS or name.startswith('_'):
      continue
    # Datasets are assumed to have no required positional arguments.
    fn = getattr(cls, name)
    args = list(inspect.signature(fn).parameters.values())
    if (inspect.isclass(cls) and not inspect.ismethod(fn) and
        not isinstance(inspect.getattr_static(cls, name, None), staticmethod)):
      # Plain functions obtained from inspecting a class are unbound, so their
      # signature includes a leading 'self' that should be ignored. Methods
      # obtained from an instance, and classmethods obtained from a class, are
      # bound (inspect.ismethod() is True) and already omit it. Staticmethods
      # have no implicit first argument, so nothing may be dropped for them.
      args = args[1:]
    positional_arguments = [
        p.name for p in args
        if p.default is p.empty and p.kind not in variadic_kinds
    ]
    if not positional_arguments:
      datasets.append(name)
      continue
    message = (f'Found a public function {name} in {class_name} with '
               f'required positional arguments: {positional_arguments}.')
    if warn_on_error:
      logging.warning(message)
    else:
      raise DatasetFunctionError(message)
  return datasets
def GetDatasetsAst(base_dir: str, model: str) -> List[str]:
  """Gets datasets without importing any code, by parsing source with `ast`.

  Useful when running from python interpreter without bazel build.

  For a model 'a.b.c.ModelName', the file 'a/b/c.py' under `base_dir` is
  parsed. If a variant with a 'params' directory inserted into the path
  exists (e.g. 'a/params/b/c.py' or 'a/b/params/c.py'), that variant is used
  instead. Public methods defined in the body of class `ModelName` are
  reported, together with those of its base classes, found recursively. Bases
  referenced as `base_model_params.<Name>` are skipped. Module-level
  functions and methods of other classes are ignored.

  Args:
    base_dir: Base directory to search in.
    model: Dotted model name of the form 'path.to.module.ClassName', relative
      to base_dir.

  Returns:
    A sorted list of strings containing names of dataset functions for the
    model. May not be accurate.

  Raises:
    NotImplementedError: if the model class (or one of its parsed base
      classes) defines GetAllDatasetParams, which cannot be evaluated here.
    Exception: if anything else goes wrong (e.g. the module file is missing
      or cannot be parsed).
  """
  parts = model.split('.')
  model_name = parts[-1]
  parts = parts[:-1]

  module = os.path.join(base_dir, '/'.join(parts)) + '.py'
  for i in range(1, len(parts)):
    # Insert params somewhere in the middle of parts.
    candidate = os.path.join(base_dir, '/'.join(parts[:i]), 'params',
                             '/'.join(parts[i:])) + '.py'
    if os.path.exists(candidate):
      module = candidate
      break

  # Read bytes so that ast.parse honours any PEP 263 coding declaration
  # instead of decoding with the platform's default encoding.
  with open(module, 'rb') as f:
    tree = ast.parse(f.read(), filename=module)

  class DatasetsVisitor(ast.NodeVisitor):
    """NodeVisitor for collecting datasets for a model."""

    def __init__(self):
      self.datasets = set()
      # Maps a locally bound import name to its fully qualified dotted name.
      self._imports = {}
      # True only while visiting the body of the `model_name` class. Without
      # it, module-level functions would be reported as datasets.
      self._in_model_class = False

    def visit_Import(self, node):  # pylint: disable=invalid-name
      """Visit a 'import symbol [as alias]' definition."""
      for alias in node.names:
        self._imports[alias.asname or alias.name] = alias.name

    def visit_ImportFrom(self, node):  # pylint: disable=invalid-name
      """Visit a 'from module import symbol [as alias]' definition."""
      if node.level:
        # Relative import ('from . import x', 'from ..pkg import x'): resolve
        # it against the package of the model module on a best-effort basis.
        # node.module is None for the bare 'from . import x' form.
        prefix = parts[:max(0, len(parts) - node.level)]
        if node.module:
          prefix = prefix + [node.module]
      else:
        prefix = [node.module]
      for alias in node.names:
        self._imports[alias.asname or alias.name] = '.'.join(
            prefix + [alias.name])

    def visit_ClassDef(self, node):  # pylint: disable=invalid-name
      """Visit a class definition; only the model class is inspected."""
      if node.name != model_name:
        return
      for base in node.bases:
        if isinstance(base, ast.Name):
          # A superclass in the same file.
          base_model = '.'.join(parts + [base.id])
        elif isinstance(base, ast.Attribute):
          # A superclass in a different file.
          if base.value.id == 'base_model_params':
            continue
          base_model = '.'.join([self._imports[base.value.id], base.attr])
        else:
          # Other base expressions (calls, subscripts, ...) are not resolved.
          continue
        self.datasets |= set(GetDatasetsAst(base_dir, base_model))
      was_in_model_class = self._in_model_class
      self._in_model_class = True
      try:
        self.generic_visit(node)
      finally:
        self._in_model_class = was_in_model_class

    def visit_FunctionDef(self, node):  # pylint: disable=invalid-name
      """Visit a function definition."""
      if not self._in_model_class:
        # Not a method of the model class (e.g. a module-level helper).
        return
      if node.name == 'GetAllDatasetParams':
        # It may be possible to parse the ast for GetAllDatasetParams to find
        # the dictionary keys, but this gets significantly harder when super()
        # calls need to be taken into consideration.
        raise NotImplementedError(
            'GetDatasetsAst does not support models using GetAllDatasetParams.')
      if (node.name not in NON_DATASET_MEMBERS and
          not node.name.startswith('_')):
        self.datasets.add(node.name)

  visitor = DatasetsVisitor()
  visitor.visit(tree)
  # Sorted for deterministic output (set order depends on hash seeding).
  return sorted(visitor.datasets)