lingvo.tasks.car.detection_3d_lib module

Library of useful functions for working with 3D object detection.

class lingvo.tasks.car.detection_3d_lib.Utils3D[source]

Bases: object

Helper routines for 3D detection problems.

One common method to do 3D anchor box assignment is to match anchors to ground-truth bboxes. First, we generate proposal anchors at given priors (center locations, dimension prior, offset prior) to tile the input space. After tiling the input space, each anchor can be assigned to a ground-truth bbox by measuring IOU and taking a threshold. Note that a ground-truth bbox may be assigned to multiple anchors - this is expected, and is managed at inference time by non-max suppression.

Note: This implementation is designed to be used at input generation, and does not support a batch dimension.

The following functions in this utility class help with that (a minimal end-to-end sketch follows the list):

  • CreateDenseCoordinates: Makes it easy to create a dense grid of coordinates that would usually correspond to center locations.

  • MakeAnchorBoxes: Given a list of center coordinates, dimension priors, and offset priors, this function creates actual anchor bbox parameters at each coordinate. More than one box can be at each center.

  • IOUAxisAlignedBoxes: This function computes the IOU between two lists of boxes.

  • AssignAnchors: This function assigns each anchor a ground-truth bbox. Note that one ground-truth bbox can be assigned to multiple anchors. The model is expected to regress the residuals between each anchor and its corresponding ground-truth bbox parameters.
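For illustration, a minimal end-to-end sketch of the workflow above; it assumes Utils3D takes no constructor arguments, and the grid extents, priors, and integer label/mask dtypes are illustrative:

import tensorflow as tf
from lingvo.tasks.car import detection_3d_lib

utils_3d = detection_3d_lib.Utils3D()

# Dense (x, y, z) center locations over a coarse 10x10 grid at z=0.
centers = utils_3d.CreateDenseCoordinates([
    (-40., 40., 10),  # x: 10 steps from -40 to 40
    (-40., 40., 10),  # y: 10 steps from -40 to 40
    (0., 0., 1),      # z: a single height
])  # [100, 3]

# One dimension prior, offset prior, and rotation per center.
anchors = utils_3d.MakeAnchorBoxes(
    anchor_centers=centers,
    anchor_box_dimensions=tf.constant([[1.6, 3.9, 1.56]]),
    anchor_box_offsets=tf.constant([[0., 0., 0.]]),
    anchor_box_rotations=tf.constant([0.]))
anchors = tf.reshape(anchors, [-1, 7])  # [A, 7]

# Assign anchors to ground-truth boxes by thresholded similarity (IOU).
gt_bboxes = tf.constant([[0., 0., 0.78, 1.6, 3.9, 1.56, 0.]])
assigned = utils_3d.AssignAnchors(
    anchors, gt_bboxes,
    gt_bboxes_labels=tf.constant([1]),
    gt_bboxes_mask=tf.constant([1]))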

ScaledHuberLoss(labels, predictions, weights=1.0, delta=1.0)[source]

Scaled Huber (SmoothL1) Loss.

This function wraps tf.losses.huber_loss to rescale it by 1 / delta, and uses Reduction.NONE.

This scaling results in the following formulation instead:

(1/d) * 0.5 * x^2    if |x| <= d
|x| - 0.5 * d        if |x| > d

where x is labels - predictions.

Hence, delta changes where the quadratic bowl is, but does not change the overall shape of the loss outside of delta.

Parameters
  • labels – The ground truth output tensor, same dimensions as ‘predictions’.

  • predictions – The predicted outputs.

  • weights – Optional Tensor whose rank is either 0, or the same rank as labels, and must be broadcastable to labels (i.e., all dimensions must be either 1, or the same as the corresponding losses dimension).

  • delta – float, the point where the huber loss function changes from a quadratic to linear.

Returns

Weighted loss float Tensor. This has the same shape as labels.
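For reference, the scaled formulation above transcribes directly into code; this is a sketch of the math, not the library implementation (which wraps tf.losses.huber_loss):

import tensorflow as tf

def scaled_huber(labels, predictions, delta=1.0):
  # Quadratic inside |x| <= delta, linear outside, with the quadratic
  # branch rescaled by 1 / delta so the two branches meet at |x| = delta.
  x = tf.abs(labels - predictions)
  return tf.where(x <= delta, 0.5 * x * x / delta, x - 0.5 * delta)

# With delta=1: residual 0.5 -> 0.125 (quadratic); residual 4.0 -> 3.5 (linear).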

CornerLoss(gt_bboxes, predicted_bboxes, symmetric=True)[source]

Corner regularization loss.

This function computes the corner loss, an alternative regression loss for box residuals. This was used in the Frustum-PointNets paper [1].

We compute all 8 corners of the predicted bboxes and apply a SmoothL1 loss between the corners of the predicted and ground-truth boxes. Hence, this loss can help encourage the model to maximize the IoU of the predictions.

[1] Frustum PointNets for 3D Object Detection from RGB-D Data

https://arxiv.org/pdf/1711.08488.pdf

Parameters
  • gt_bboxes – tf.float32 of shape […, 7] which contains (x, y, z, dx, dy, dz, phi), corresponding to ground truth bbox parameters.

  • predicted_bboxes – tf.float32 of same shape as gt_bboxes containing predicted bbox parameters.

  • symmetric – boolean. If True, computes the minimum of the corner loss with respect to both the gt box and the gt box rotated 180 degrees.

Returns

tf.float32 Tensor of shape […] where each entry contains the corner loss for the corresponding bbox.
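A minimal usage sketch, reusing utils_3d from the earlier example (box values are illustrative):

# Two boxes that differ slightly in center and heading.
gt = tf.constant([[0., 0., 0., 1.6, 3.9, 1.56, 0.0]])
pred = tf.constant([[0.2, 0., 0., 1.6, 3.9, 1.56, 0.1]])
loss = utils_3d.CornerLoss(gt_bboxes=gt, predicted_bboxes=pred)  # shape [1]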

CreateDenseCoordinates(ranges)[source]

Create a matrix of coordinate locations corresponding to a dense grid.

Example: To create (x, y) coordinates over a 10x10 grid with step size 1, call CreateDenseCoordinates([(1, 10, 10), (1, 10, 10)]).

Parameters

ranges – A list of 3-tuples, each of which contains (min, max, num_steps). Each list element corresponds to one dimension. Each tuple will be passed into np.linspace to create the values for a single dimension.

Returns

tf.float32 tensor of shape [total_points, len(ranges)], where total_points = product of all num_steps.
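Running the example above (reusing utils_3d from the earlier sketch; the exact row ordering shown is an assumption):

coords = utils_3d.CreateDenseCoordinates([(1, 10, 10), (1, 10, 10)])
# coords: tf.float32, shape [100, 2]. Rows enumerate the Cartesian product
# of the two linspaces, e.g. (1, 1), (1, 2), ..., (10, 10).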

MakeAnchorBoxes(anchor_centers, anchor_box_dimensions, anchor_box_offsets, anchor_box_rotations=None)[source]

Create anchor boxes from centers, dimensions, offsets.

Parameters
  • anchor_centers – [A, dims] tensor. Center locations to generate boxes at.

  • anchor_box_dimensions – [B, dims] tensor corresponding to dimensions of each box. The inner-most dimension of this tensor must match anchor_centers.

  • anchor_box_offsets – [B, dims] tensor corresponding to offsets of each box.

  • anchor_box_rotations – [B] tensor corresponding to rotation of each box. If None, rotation will be set to 0.

Returns

A [num_anchors_center, num_boxes_per_center, 2 * dims + 1] tensor. Usually dims=3 for 3D, where [..., :dims] corresponds to location, [..., dims:2*dims] corresponds to dimensions, and [..., -1] corresponds to rotation.
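A shape sketch with two boxes per center, reusing utils_3d from the earlier example (prior values are illustrative):

centers = utils_3d.CreateDenseCoordinates(
    [(0., 4., 5), (0., 4., 5), (0., 0., 1)])  # [25, 3]
dims = tf.constant([[1.6, 3.9, 1.56],
                    [3.9, 1.6, 1.56]])        # B=2 dimension priors
offsets = tf.zeros([2, 3])                    # B=2 offset priors
rotations = tf.constant([0., 1.5707964])      # headings 0 and pi/2
boxes = utils_3d.MakeAnchorBoxes(centers, dims, offsets, rotations)
# boxes: [25, 2, 7]; [..., :3] location, [..., 3:6] dimensions, [..., 6] rotation.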

IOU2DRotatedBoxes(bboxes_u, bboxes_v)[source]

Computes IoU between every pair of bboxes with headings.

This function ignores the z dimension, which is not usually considered during anchor assignment.

Parameters
  • bboxes_u – tf.float32. [U, dims]. […, :7] are (x, y, z, dx, dy, dz, r).

  • bboxes_v – tf.float32. [V, dims]. […, :7] are (x, y, z, dx, dy, dz, r).

Returns

tf.float32 tensor with shape [U, V], where [i, j] is the IoU between the i-th bbox of bboxes_u and the j-th bbox of bboxes_v.
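A small sketch, reusing utils_3d from the earlier example. Because the z dimension is ignored, a box directly above another with the same x-y footprint still scores IoU 1:

u = tf.constant([[0., 0., 0., 2., 2., 2., 0.]])
v = tf.constant([[0., 0., 5., 2., 2., 2., 0.],     # same footprint, higher z
                 [10., 10., 0., 2., 2., 2., 0.]])  # disjoint footprint
iou = utils_3d.IOU2DRotatedBoxes(u, v)  # [1, 2], approximately [[1.0, 0.0]]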

AssignAnchors(anchor_bboxes, gt_bboxes, gt_bboxes_labels, gt_bboxes_mask, foreground_assignment_threshold=0.5, background_assignment_threshold=0.35, background_class_id=0, force_match=True, similarity_fn=None)[source]

Assigns anchors to bboxes using a similarity function (SSD-based).

Each anchor box is assigned to the top matching ground truth box. Ground truth boxes can be assigned to multiple anchor boxes.

Assignments can result in 3 outcomes:

  • Positive assignment (if score >= foreground_assignment_threshold): assigned_gt_labels will reflect the assigned box label and assigned_cls_mask will be set to 1.0

  • Background assignment (if score <= background_assignment_threshold): assigned_gt_labels will be background_class_id and assigned_cls_mask will be set to 1.0

  • Ignore assignment (otherwise): assigned_gt_labels will be background_class_id and assigned_cls_mask will be set to 0.0

The detection loss function would usually:

  • Use assigned_cls_mask for weighting the classification loss. The mask is set such that the loss applies to foreground and background assignments only - ignored anchors will be set to 0.

  • Use assigned_reg_mask for weighting the regression loss. The mask is set such that the loss applies to foreground assignments only. (A loss-weighting sketch follows the returns list below.)

The thresholds (foreground_assignment_threshold and background_assignment_threshold) should be tuned per dataset.

TODO(jngiam): Consider having a separate threshold for regression boxes; a separate threshold is used in PointRCNN.

Parameters
  • anchor_bboxes – tf.float32. [A, 7], where […, :] corresponds to box parameters (x, y, z, dx, dy, dz, r).

  • gt_bboxes – tf.float32. [G, 7], where […, :] corresponds to ground truth box parameters (x, y, z, dx, dy, dz, r).

  • gt_bboxes_labels – tensor with shape [G]. Ground truth labels for each bounding box.

  • gt_bboxes_mask – tensor with shape [G]. Mask for ground truth boxes, 1 iff the gt_bbox is a real bbox.

  • foreground_assignment_threshold – Similarity score threshold for assigning foreground bounding boxes; scores need to be >= foreground_assignment_threshold to be assigned to foreground.

  • background_assignment_threshold – Similarity score threshold for assigning background bounding boxes; scores need to be <= background_assignment_threshold to be assigned to background.

  • background_class_id – class id to be assigned to anchors_gt_class if no anchor boxes match.

  • force_match – Boolean specifying whether force matching is enabled. If force matching is enabled, then an anchor that is the highest scoring match for a ground-truth box is considered a foreground match as long as its similarity score > 0.

  • similarity_fn – Function that computes a similarity score (e.g., IOU) between pairs of bounding boxes. This function should take in two tensors corresponding to anchor and ground-truth bboxes, and return a matrix [A, G] with the similarity score between each pair of bboxes. The score must be non-negative, with greater scores indicating greater similarity. The fore/background_assignment_thresholds are applied to this score to determine whether an anchor is foreground, background, or ignored. If set to None, the function defaults to IOU2DRotatedBoxes.

Returns

NestedMap with the following keys

  • assigned_gt_idx: shape [A] index corresponding to the index of the assigned ground truth box. Anchors not assigned to a ground truth box will have the index set to -1.

  • assigned_gt_bbox: shape [A, 7] bbox parameters assigned to each anchor.

  • assigned_gt_similarity_score: shape [A] similarity (e.g., IoU) score between the anchor and its assigned gt bbox.

  • assigned_gt_labels: shape [A] label assigned to bbox.

  • assigned_cls_mask: shape [A] mask for classification loss per anchor. This should be 1.0 if the anchor has a foreground or background assignment; otherwise, it will be assigned to 0.0.

  • assigned_reg_mask: shape [A] mask for regression loss per anchor. This should be 1.0 if the anchor has a foreground assignment; otherwise, it will be assigned to 0.0. Note: background anchors do not have regression targets.
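Tying these outputs to the loss-weighting description above, a minimal sketch reusing anchors, assigned, and utils_3d from the assignment sketch earlier; cls_logits ([A, num_classes]) and reg_pred ([A, 7]) are assumed per-anchor model outputs, and assigned_gt_labels is assumed to be an integer tensor:

cls_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=assigned.assigned_gt_labels, logits=cls_logits)
cls_loss = tf.reduce_sum(cls_loss * assigned.assigned_cls_mask)

reg_targets = utils_3d.LocalizationResiduals(
    anchors, assigned.assigned_gt_bbox)
per_dim_loss = utils_3d.ScaledHuberLoss(reg_targets, reg_pred)
reg_loss = tf.reduce_sum(
    per_dim_loss * assigned.assigned_reg_mask[:, tf.newaxis])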

LocalizationResiduals(anchor_bboxes, assigned_gt_bboxes)[source]

Computes the anchor residuals for every bbox.

For a given bbox, compute residuals in the following way:

Let anchor_bbox = (x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a) and assigned_gt_bbox = (x_gt, y_gt, z_gt, dx_gt, dy_gt, dz_gt, phi_gt)

Define diagonal_xy = sqrt(dx_a^2 + dy_a^2)

Then the corresponding residuals are given by:

x_residual = (x_gt - x_a) / (diagonal_xy)
y_residual = (y_gt - y_a) / (diagonal_xy)
z_residual = (z_gt - z_a) / (dz_a)

dx_residual = log(dx_gt / dx_a)
dy_residual = log(dy_gt / dy_a)
dz_residual = log(dz_gt / dz_a)

phi_residual = phi_gt - phi_a

The normalization for x and y residuals by the diagonal was first proposed by [1]. Intuitively, this reflects that objects can usually move freely in the x-y plane, including diagonally. On the other hand, moving in the z-axis (up and down) can be considered orthogonal to x-y.

For phi_residual, one way to frame the loss is with SmoothL1(sine(phi_residual - phi_predicted)). The use of sine to wrap the phi residual was proposed by [2]. This stems from the observation that bboxes at phi and phi + pi are the same bbox, fully overlapping in 3D space, except that the direction is different. Note that the use of sine makes this residual invariant to direction when a symmetric loss like SmoothL1 is used. In ResidualsToBBoxes, we ensure that the phi predicted is between [0, pi).

The Huber (SmoothL1) loss can then be applied to the delta between these target residuals and the model predicted residuals.

[1] VoxelNet: End-to-End Learning for Point Cloud Based 3D Object Detection

https://arxiv.org/abs/1711.06396

[2] SECOND: Sparsely Embedded Convolutional Detection

https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf

Parameters
  • anchor_bboxes – tf.float32, where […, :7] contains (x, y, z, dx, dy, dz, phi), corresponding to the parameters of each anchor bbox.

  • assigned_gt_bboxes – tf.float32 of the same shape as anchor_bboxes containing the corresponding assigned ground-truth bboxes.

Returns

A tf.float32 tensor of the same shape as anchor_bboxes with target residuals for every corresponding bbox.
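The formulas above transcribe directly into code; this is a sketch of the math, not the library implementation, and assumes inputs of shape [..., 7]:

def localization_residuals(anchor, gt):
  x_a, y_a, z_a, dx_a, dy_a, dz_a, phi_a = tf.unstack(anchor, axis=-1)
  x_g, y_g, z_g, dx_g, dy_g, dz_g, phi_g = tf.unstack(gt, axis=-1)
  diagonal_xy = tf.sqrt(dx_a**2 + dy_a**2)
  return tf.stack([
      (x_g - x_a) / diagonal_xy,  # x_residual
      (y_g - y_a) / diagonal_xy,  # y_residual
      (z_g - z_a) / dz_a,         # z_residual
      tf.math.log(dx_g / dx_a),   # dx_residual
      tf.math.log(dy_g / dy_a),   # dy_residual
      tf.math.log(dz_g / dz_a),   # dz_residual
      phi_g - phi_a,              # phi_residual
  ], axis=-1)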

ResidualsToBBoxes(anchor_bboxes, residuals, min_angle_rad=-3.141592653589793, max_angle_rad=3.141592653589793)[source]

Converts anchor_boxes and residuals to predicted bboxes.

This converts predicted residuals into bboxes using the following formulae:

x_predicted = x_a + x_residual * diagonal_xy
y_predicted = y_a + y_residual * diagonal_xy
z_predicted = z_a + z_residual * dz_a

dx_predicted = dx_a * exp(dx_residual)
dy_predicted = dy_a * exp(dy_residual)
dz_predicted = dz_a * exp(dz_residual)

# Adding the residual, and bounding it between
# [min_angle_rad, max_angle_rad]
phi_predicted = NormalizeAngleRad(phi_a + phi_residual,
                                  min_angle_rad, max_angle_rad)

These equations follow from those in LocalizationResiduals, where we solve for the *_gt variables.

Parameters
  • anchor_bboxes – tf.float32, where […, :7] contains (x, y, z, dx, dy, dz, phi), corresponding to the parameters of each anchor bbox.

  • residuals – tf.float32 of the same shape as anchor_bboxes containing predicted residuals at each anchor location.

  • min_angle_rad – Scalar with the minimum angle allowed (before wrapping) in radians.

  • max_angle_rad – Scalar with the maximum angle allowed (before wrapping) in radians. This value usually should be pi.

Returns

A tf.float32 tensor of the same shape as anchor_bboxes with predicted bboxes.
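Since these equations invert LocalizationResiduals, encoding then decoding should round-trip up to angle wrapping (reusing anchors, assigned, and utils_3d from the assignment sketch above):

residuals = utils_3d.LocalizationResiduals(anchors, assigned.assigned_gt_bbox)
decoded = utils_3d.ResidualsToBBoxes(anchors, residuals)
# decoded matches assigned.assigned_gt_bbox, with phi wrapped into
# [min_angle_rad, max_angle_rad].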

NMSIndices(bboxes, scores, max_output_size, nms_iou_threshold=0.3, score_threshold=0.01)[source]

Apply NMS to a series of 3d bounding boxes in 7-DOF format.

Parameters
  • bboxes – A [num_boxes, 7] floating point Tensor of bounding boxes in [x, y, z, dx, dy, dz, phi] format.

  • scores – A [num_boxes] floating point Tensor containing box scores.

  • max_output_size – Maximum number of boxes to predict per input.

  • nms_iou_threshold – IoU threshold to use when determining whether two boxes overlap for purposes of suppression.

  • score_threshold – The score threshold passed to NMS that allows NMS to quickly ignore irrelevant boxes.

Returns

The NMS indices and the mask of the padded indices.
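A usage sketch, reusing utils_3d from the earlier example; the box values are random placeholders, and the dtype of the returned mask is an assumption:

bboxes = tf.random.uniform([50, 7])  # [num_boxes, 7], illustrative
scores = tf.random.uniform([50])     # [num_boxes]
indices, mask = utils_3d.NMSIndices(
    bboxes, scores, max_output_size=4,
    nms_iou_threshold=0.3, score_threshold=0.01)
kept = tf.gather(bboxes, indices)  # [4, 7]; entries where mask is 0 are padding.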

BatchedNMSIndices(bboxes, scores, nms_iou_threshold=0.3, score_threshold=0.01, max_num_boxes=None)[source]

Batched version of NMSIndices.

Parameters
  • bboxes – A [batch_size, num_boxes, 7] floating point Tensor of bounding boxes in [x, y, z, dx, dy, dz, phi] format.

  • scores – A [batch_size, num_boxes, num_classes] floating point Tensor containing box scores.

  • nms_iou_threshold – IoU threshold to use when determining whether two boxes overlap for purposes of suppression.

  • score_threshold – The score threshold passed to NMS that allows NMS to quickly ignore irrelevant boxes.

  • max_num_boxes – The maximum number of boxes per example to emit. If None, this value is set to num_boxes from the shape of bboxes.

Returns

The NMS indices and the mask of the padded indices for each example in the batch.

BatchedOrientedNMSIndices(bboxes, scores, nms_iou_threshold, score_threshold, max_boxes_per_class)[source]

Runs a batched version of per-class 3D (7-DOF) non-max suppression.

All outputs have shape [batch_size, num_classes, max_boxes_per_class].

Parameters
  • bboxes – A [batch_size, num_boxes, 7] floating point Tensor of bounding boxes in [x, y, z, dx, dy, dz, phi] format.

  • scores – A [batch_size, num_boxes, num_classes] floating point Tensor containing box scores.

  • nms_iou_threshold – Either a float or a list of floats of len num_classes with the IoU threshold to use when determining whether two boxes overlap for purposes of suppression.

  • score_threshold – Either a float or a list of floats of len num_classes with the score threshold that allows NMS to quickly ignore boxes.

  • max_boxes_per_class – An integer scalar with the maximum number of boxes per example to emit per class.

Returns

  • bbox_indices: An int32 Tensor with the indices of the chosen boxes. Values are in sort order until the class_idx switches.

  • bbox_scores: A float32 Tensor with the score for each box.

  • valid_mask: A float32 Tensor with 1/0 values indicating the validity of each box. 1 indicates valid, and 0 invalid.

Return type

A tuple of 3 tensors
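A shape sketch with per-class IoU thresholds for three classes, reusing utils_3d from the earlier example (all values illustrative):

bboxes = tf.random.uniform([2, 50, 7])  # [batch, num_boxes, 7]
scores = tf.random.uniform([2, 50, 3])  # [batch, num_boxes, num_classes=3]
bbox_indices, bbox_scores, valid_mask = utils_3d.BatchedOrientedNMSIndices(
    bboxes, scores,
    nms_iou_threshold=[0.7, 0.5, 0.5],  # one IoU threshold per class
    score_threshold=0.01,
    max_boxes_per_class=128)
# Each output has shape [2, 3, 128].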

CornersToImagePlane(corners, velo_to_image_plane)[source]

Project 3d box corners to the image plane.

Parameters
  • corners – A [batch, num_boxes, 8, 3] floating point tensor containing the 8 corners points for each 3d bounding box.

  • velo_to_image_plane – A [batch, 3, 4] batch set of projection matrices from velo xyz to image plane xy. After multiplication, you need to divide by last coordinate to recover 2D pixel locations.

Returns

A [batch, num_boxes, 8, 2] floating point Tensor containing the 3D bounding box corners projected to the image plane.
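The projection described above can be written out by hand as follows; this is a sketch equivalent to the description, not the library code, with random placeholder inputs:

corners = tf.random.uniform([2, 10, 8, 3])          # [batch, num_boxes, 8, 3]
velo_to_image_plane = tf.random.uniform([2, 3, 4])  # [batch, 3, 4]
ones = tf.ones_like(corners[..., :1])
hom = tf.concat([corners, ones], axis=-1)           # homogeneous (x, y, z, 1)
proj = tf.einsum('bij,bnkj->bnki', velo_to_image_plane, hom)  # [batch, num_boxes, 8, 3]
pixels = proj[..., :2] / proj[..., 2:3]             # divide by the last coordinate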

lingvo.tasks.car.detection_3d_lib.RandomPadOrTrimTo(tensor_list, num_points_out, seed=None)[source]

Pads or Trims a list of Tensors on the major dimension.

Slices if there are more points than num_points_out, or pads if there are fewer.

In this implementation:

Padded points are random duplications of real points. Sliced points are a random subset of the real points.

Parameters
  • tensor_list – A list of tf.Tensor objects to pad or trim along first dim. All tensors are expected to have the same first dimension.

  • num_points_out – An int for the requested number of points to trim/pad to.

  • seed – Random seed to use for random generators.

Returns

A tuple of output_tensors and a padding indicator.

  • output_tensors: A list of padded or trimmed versions of our tensor_list input tensors, all with the same first dimension.

  • padding: A tf.float32 tf.Tensor of shape [num_points_out] with 0 if the point is real, 1 if it is padded.
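A usage sketch padding a 600-point cloud and its aligned features up to 1024 points (tensor shapes are illustrative):

import tensorflow as tf
from lingvo.tasks.car import detection_3d_lib

points = tf.random.uniform([600, 3])
features = tf.random.uniform([600, 8])
out_tensors, padding = detection_3d_lib.RandomPadOrTrimTo(
    [points, features], num_points_out=1024, seed=0)
out_points, out_features = out_tensors  # [1024, 3], [1024, 8]
# padding: [1024] tf.float32; 0 for real points, 1 for duplicated padding.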