# Lint as: python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train models on KITTI data."""
import os
from lingvo import compat as tf
from lingvo import model_registry
from lingvo.core import base_model_params
from lingvo.core import cluster_factory
from lingvo.core import datasource
from lingvo.core import optimizer
from lingvo.core import py_utils
from lingvo.tasks.car import input_preprocessors
from lingvo.tasks.car import kitti_input_generator
from lingvo.tasks.car import lr_util
from lingvo.tasks.car import starnet
import numpy as np
# Set $KITTI_DIR to the base path of where all the KITTI files can be found.
#
# E.g., 'gs://your-bucket/kitti/3d'
_KITTI_BASE = os.environ.get('KITTI_DIR', 'FILL-ME-IN')
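# For example (illustrative bucket path only), export the variable before
# launching a job:
#   export KITTI_DIR=gs://your-bucket/kitti/3d
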
# Specifications for the different dataset splits.
def KITTITrainSpec(params):
  p = params.Copy()
  p.file_datasource.file_pattern = (
      'kitti_object_3dop_train.tfrecord-*-of-00100')
  p.num_samples = 3712
  return p


def KITTIValSpec(params):
  p = params.Copy()
  p.file_datasource.file_pattern = ('kitti_object_3dop_val.tfrecord-*-of-00100')
  p.num_samples = 3769
  return p


def KITTITestSpec(params):
  p = params.Copy()
  p.file_datasource.file_pattern = ('kitti_object_test.tfrecord-*-of-00100')
  p.num_samples = 7518
  return p


class KITTITrain(kitti_input_generator.KITTILaser):
  """KITTI train set with raw laser data."""

  @classmethod
  def Params(cls):
    """Default params."""
    p = super().Params()
    return KITTITrainSpec(p)


class KITTIValidation(kitti_input_generator.KITTILaser):
  """KITTI validation set with raw laser data."""

  @classmethod
  def Params(cls):
    """Default params."""
    p = super().Params()
    return KITTIValSpec(p)


class KITTITest(kitti_input_generator.KITTILaser):
  """KITTI test set with raw laser data."""

  @classmethod
  def Params(cls):
    p = super().Params()
    return KITTITestSpec(p)


class KITTIGridTrain(kitti_input_generator.KITTIGrid):
  """KITTI train set with grid laser data."""

  @classmethod
  def Params(cls):
    p = super().Params()
    return KITTITrainSpec(p)


class KITTIGridValidation(kitti_input_generator.KITTIGrid):
  """KITTI validation set with grid laser data."""

  @classmethod
  def Params(cls):
    p = super().Params()
    return KITTIValSpec(p)


class KITTIGridTest(kitti_input_generator.KITTIGrid):
  """KITTI test set with grid laser data."""

  @classmethod
  def Params(cls):
    p = super().Params()
    return KITTITestSpec(p)


class KITTISparseLaserTrain(kitti_input_generator.KITTISparseLaser):
  """KITTI train set with sparse laser data."""

  @classmethod
  def Params(cls):
    p = super().Params()
    return KITTITrainSpec(p)


class KITTISparseLaserValidation(kitti_input_generator.KITTISparseLaser):
  """KITTI validation set with sparse laser data."""

  @classmethod
  def Params(cls):
    p = super().Params()
    return KITTIValSpec(p)


class KITTISparseLaserTest(kitti_input_generator.KITTISparseLaser):
  """KITTI test set with sparse laser data."""

  @classmethod
  def Params(cls):
    p = super().Params()
    return KITTITestSpec(p)


def _MaybeRemove(values, key):
  """Remove the entry 'key' from 'values' if present."""
  if key in values:
    values.remove(key)


def AddLaserAndCamera(params):
  """Adds laser and camera extractors."""
  cluster = cluster_factory.Current()
  job = cluster.job
  if job != 'decoder':
    return params

  extractor_params = list(dict(params.extractors.IterParams()).values())
  extractor_classes = [p.cls for p in extractor_params]

  # Add images if not present.
  if kitti_input_generator.KITTIImageExtractor not in extractor_classes:
    params.extractors.Define('images',
                             kitti_input_generator.KITTIImageExtractor.Params(),
                             '')

  # Add raw lasers if not present.
  if kitti_input_generator.KITTILaserExtractor not in extractor_classes:
    labels = None
    for p in extractor_params:
      if p.cls == kitti_input_generator.KITTILabelExtractor:
        labels = p
    if labels is None:
      labels = kitti_input_generator.KITTILabelExtractor.Params()
    params.extractors.Define(
        'lasers', kitti_input_generator.KITTILaserExtractor.Params(labels), '')

  return params
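
# Usage sketch: decoder jobs typically want camera images and raw lasers
# available for visualization, so an input config can be wrapped as, e.g.,
#   input_p = KITTISparseLaserValidation.Params()
#   input_p = AddLaserAndCamera(input_p)
# On non-decoder jobs the params are returned unchanged.
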
################################################################################
# StarNet
################################################################################


@model_registry.RegisterSingleTaskModel
class StarNetCarsBase(base_model_params.SingleTaskModelParams):
  """StarNet model for cars."""
  RUN_LOCALLY = False
  NUM_ANCHOR_BBOX_OFFSETS = 25
  NUM_ANCHOR_BBOX_ROTATIONS = 4
  NUM_ANCHOR_BBOX_DIMENSIONS = 1
  FOREGROUND_ASSIGNMENT_THRESHOLD = 0.6
  BACKGROUND_ASSIGNMENT_THRESHOLD = 0.45
  INCLUDED_CLASSES = ['Car']

  class AnchorBoxSettings(input_preprocessors.SparseCarV1AnchorBoxSettings):
    ROTATIONS = [0, np.pi / 2, 3. * np.pi / 4, np.pi / 4]

  def Train(self):
    p = KITTISparseLaserTrain.Params()
    self._configure_input(p)
    return p

  def Test(self):
    p = KITTISparseLaserTest.Params()
    self._configure_input(p)
    return p

  def Dev(self):
    p = KITTISparseLaserValidation.Params()
    self._configure_input(p)
    return p

  def Task(self):
    num_classes = len(
        kitti_input_generator.KITTILabelExtractor.KITTI_CLASS_NAMES)
    p = starnet.ModelV2.Params(
        num_classes,
        num_anchor_bboxes_offsets=self.NUM_ANCHOR_BBOX_OFFSETS,
        num_anchor_bboxes_rotations=self.NUM_ANCHOR_BBOX_ROTATIONS,
        num_anchor_bboxes_dimensions=self.NUM_ANCHOR_BBOX_DIMENSIONS)
    p.name = 'sparse_detector'

    tp = p.train
    tp.optimizer = optimizer.Adam.Params()
    tp.clip_gradient_norm_to_value = 5

    ep = p.eval
    # Evaluate the whole dataset.
    ep.samples_per_summary = 0

    # To be tuned.
    p.train.l2_regularizer_weight = 1e-4

    # Adapted from V1 tuning.
    tp.ema_decay = 0.99
    # TODO(b/148537111): consider setting this to True.
    tp.ema_decay_moving_vars = False
    tp.learning_rate = 0.001
    lr_util.SetExponentialLR(
        train_p=tp,
        train_input_p=self.Train(),
        exp_start_epoch=150,
        total_epoch=650)

    p.dimension_loss_weight = .3
    p.location_loss_weight = 3.
    p.loss_weight_classification = 1.
    p.loss_weight_localization = 3.
    p.rotation_loss_weight = 0.3
    return p
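
# Note: registered models in this file are selected by registry name when
# launching the lingvo trainer, e.g. --model=car.kitti.StarNetCarModel0701
# (assuming the standard params module layout, lingvo/tasks/car/params/kitti.py;
# adjust the name to your setup).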


@model_registry.RegisterSingleTaskModel
class StarNetCarModel0701(StarNetCarsBase):
  """StarNet Car model trained on KITTI."""

  class AnchorBoxSettings(input_preprocessors.SparseCarV1AnchorBoxSettings):
    CENTER_X_OFFSETS = np.linspace(-1.294, 1.294, 5)
    CENTER_Y_OFFSETS = np.linspace(-1.294, 1.294, 5)

  def Task(self):
    p = super().Task()

    # Builder configuration.
    builder = starnet.Builder()
    builder.linear_params_init = py_utils.WeightInit.KaimingUniformFanInRelu()
    gin_layer_sizes = [32, 256, 512, 256, 256, 128]
    num_laser_features = 1
    gin_layers = [
        # Each layer expects as input 2 * the dims of the previous layer's
        # output; the middle (hidden) layer is sized at 2 * dim_out.
        [dim_in * 2, dim_out * 2, dim_out]
        for (dim_in, dim_out) in zip(gin_layer_sizes[:-1], gin_layer_sizes[1:])
    ]
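    # With gin_layer_sizes = [32, 256, 512, 256, 256, 128] this expands to
    # [[64, 512, 256], [512, 1024, 512], [1024, 512, 256], [512, 512, 256],
    #  [512, 256, 128]]; p.cell_feature_dims below is sum(gin_layer_sizes),
    # i.e. 1440.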
    p.cell_feature_dims = sum(gin_layer_sizes)
    p.cell_featurizer = builder.GINFeaturizerV2(
        name='feat',
        fc_dims=gin_layer_sizes[0],
        mlp_dims=gin_layers,
        num_laser_features=num_laser_features,
        fc_use_bn=False)
    p.anchor_projected_feature_dims = 512

    # Loss and training params.
    p.train.learning_rate = 0.001 / 2.  # Divide by batch size.
    p.focal_loss_alpha = 0.2
    p.focal_loss_gamma = 3.0

    class_name_to_idx = kitti_input_generator.KITTILabelExtractor.KITTI_CLASS_NAMES
    num_classes = len(class_name_to_idx)
    p.per_class_loss_weight = [0.] * num_classes
    p.per_class_loss_weight[class_name_to_idx.index('Car')] = 1.

    # Decoding / NMS params.
    p.use_oriented_per_class_nms = True
    p.max_nms_boxes = 512
    p.nms_iou_threshold = [0.0] * num_classes
    p.nms_iou_threshold[class_name_to_idx.index('Car')] = 0.0831011
    p.nms_score_threshold = [1.0] * num_classes
    p.nms_score_threshold[class_name_to_idx.index('Car')] = 0.321310
    p.output_decoder.truncation_threshold = 0.65
    p.output_decoder.filter_predictions_outside_frustum = True
    return p


@model_registry.RegisterSingleTaskModel
class StarNetPedCycModel0704(StarNetCarsBase):
  """StarNet Ped/Cyc model trained on KITTI."""

  INCLUDED_CLASSES = ['Pedestrian', 'Cyclist']

  FOREGROUND_ASSIGNMENT_THRESHOLD = 0.48
  # Any value > FOREGROUND is equivalent.
  BACKGROUND_ASSIGNMENT_THRESHOLD = 0.80

  NUM_ANCHOR_BBOX_OFFSETS = 9
  NUM_ANCHOR_BBOX_ROTATIONS = 4
  NUM_ANCHOR_BBOX_DIMENSIONS = 3

  class AnchorBoxSettings(input_preprocessors.SparseCarV1AnchorBoxSettings):
    # PointPillars priors for pedestrian/cyclists.
    DIMENSION_PRIORS = [(0.6, 0.8, 1.7), (0.6, 0.6, 1.2), (0.6, 1.76, 1.73)]
    ROTATIONS = [0, np.pi / 2, 3. * np.pi / 4, np.pi / 4]
    CENTER_X_OFFSETS = np.linspace(-0.31, 0.31, 3)
    CENTER_Y_OFFSETS = np.linspace(-0.31, 0.31, 3)
    CENTER_Z_OFFSETS = [-0.6]

  def Task(self):
    p = super().Task()
    p.train.learning_rate = 7e-4

    builder = starnet.Builder()
    builder.linear_params_init = py_utils.WeightInit.KaimingUniformFanInRelu()
    gin_layer_sizes = [32, 256, 512, 256, 256, 128]
    num_laser_features = 1
    gin_layers = [
        # Each layer expects as input 2 * the dims of the previous layer's
        # output; the middle (hidden) layer is sized at 2 * dim_out.
        [dim_in * 2, dim_out * 2, dim_out]
        for (dim_in, dim_out) in zip(gin_layer_sizes[:-1], gin_layer_sizes[1:])
    ]
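    # Same GIN layer expansion as in StarNetCarModel0701.Task above;
    # p.cell_feature_dims below is sum(gin_layer_sizes) = 1440.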
    p.cell_feature_dims = sum(gin_layer_sizes)
    # Disable BN on first layer.
    p.cell_featurizer = builder.GINFeaturizerV2(
        'feat',
        gin_layer_sizes[0],
        gin_layers,
        num_laser_features,
        fc_use_bn=False)
    p.anchor_projected_feature_dims = 512

    class_name_to_idx = kitti_input_generator.KITTILabelExtractor.KITTI_CLASS_NAMES
    num_classes = len(class_name_to_idx)
    p.per_class_loss_weight = [0.] * num_classes
    p.per_class_loss_weight[class_name_to_idx.index('Pedestrian')] = 3.5
    p.per_class_loss_weight[class_name_to_idx.index('Cyclist')] = 3.25

    p.focal_loss_alpha = 0.9
    p.focal_loss_gamma = 1.25

    p.use_oriented_per_class_nms = True
    p.max_nms_boxes = 1024
    p.nms_iou_threshold = [0.0] * num_classes
    p.nms_iou_threshold[class_name_to_idx.index('Cyclist')] = 0.49
    p.nms_iou_threshold[class_name_to_idx.index('Pedestrian')] = 0.32
    p.nms_score_threshold = [1.0] * num_classes
    p.nms_score_threshold[class_name_to_idx.index('Cyclist')] = 0.11
    p.nms_score_threshold[class_name_to_idx.index('Pedestrian')] = 0.23

    p.output_decoder.filter_predictions_outside_frustum = True
    p.output_decoder.truncation_threshold = 0.65
    # Equally weight pedestrian and cyclist moderate classes.
    p.output_decoder.ap_metric.metric_weights = {
        'easy': np.array([0.0, 0.0, 0.0]),
        'moderate': np.array([0.0, 1.0, 1.0]),
        'hard': np.array([0.0, 0.0, 0.0])
    }
    return p