carabiner.tf package

Submodules

carabiner.tf.utils module

carabiner.tf.utils.get_param(dim0: Tensor, dim1: Tensor, m: Tensor) → Tensor

Matrix-multiply two dense indicator matrix tensors with a batched 2D dense tensor.

The dense tensor m has shape [batch x m x p], that is, a batch of 2D matrices. The first matrix, dim0, multiplies along the m axis, and the second matrix, dim1, multiplies along the p axis.

Parameters:
  • dim0 (Tensor) – Indicator matrix tensor.

  • dim1 (Tensor) – Indicator matrix tensor.

  • m (Tensor) – Dense tensor.

Return type:

Tensor
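A minimal NumPy sketch of the computation described above, useful for checking shapes. It assumes dim0 has shape [batch x n0 x m], dim1 has shape [batch x n1 x p], and that dim1 is applied transposed, so the result is dim0 @ m @ transpose(dim1) with shape [batch x n0 x n1]. These shapes, the transpose, and the name `get_param_reference` are assumptions for illustration, not taken from the carabiner source.

```python
import numpy as np


def get_param_reference(dim0: np.ndarray, dim1: np.ndarray, m: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for get_param (not the carabiner code).

    Assumed shapes:
      dim0: [batch, n0, m_rows]   contracts with axis 1 of m
      dim1: [batch, n1, p_cols]   contracts with axis 2 of m
      m:    [batch, m_rows, p_cols]
    Returns dim0 @ m @ transpose(dim1), shape [batch, n0, n1].
    """
    # out[s, i, j] = sum over k, l of dim0[s, i, k] * m[s, k, l] * dim1[s, j, l]
    return np.einsum("sik,skl,sjl->sij", dim0, m, dim1)


# One batch element holding a 2 x 3 matrix: [[0, 1, 2], [3, 4, 5]].
m = np.arange(6.0).reshape(1, 2, 3)
dim0 = np.eye(2)[[1, 0]][np.newaxis]  # one-hot rows picking rows 1 and 0 of m
dim1 = np.eye(3)[[2]][np.newaxis]     # one-hot row picking column 2 of m

out = get_param_reference(dim0, dim1, m)
print(out.shape)     # (1, 2, 1)
print(out.tolist())  # [[[5.0], [2.0]]]
# Same result as chaining ordinary batched matmuls.
print(np.allclose(out, dim0 @ m @ np.swapaxes(dim1, 1, 2)))  # True
```

With one-hot indicator rows, each output entry reduces to a single element of m; the index-based get_param_sparse appears to exploit this.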

carabiner.tf.utils.get_param_sparse(dim0: Tensor, dim1: Tensor, m: Tensor) → Tensor

Matrix-multiply two indicator matrix tensors, given in index form, with a batched 2D dense tensor.

Each indicator tensor is a [batch x n x 1] tensor of indices; each index gives the position of the single value set to 1 in the corresponding row of the implied one-hot matrix.

The dense tensor m has shape [batch x m x p]. The first indicator, dim0, indexes along the m axis, and the second indicator, dim1, indexes along the p axis.

Parameters:
  • dim0 (Tensor) – Indicator matrix tensor.

  • dim1 (Tensor) – Indicator matrix tensor.

  • m (Tensor) – Dense tensor.

Return type:

Tensor
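Under the same assumed shapes as the dense case, multiplying by one-hot rows reduces to indexing, so get_param_sparse can be sketched in NumPy as a batched gather. The name `get_param_sparse_reference` is hypothetical, the index shapes are assumptions, and the sketch checks itself against the dense product built from explicit one-hot matrices.

```python
import numpy as np


def get_param_sparse_reference(dim0: np.ndarray, dim1: np.ndarray, m: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for get_param_sparse (not the carabiner code).

    Assumed shapes:
      dim0: [batch, n0, 1] integer indices into axis 1 of m
      dim1: [batch, n1, 1] integer indices into axis 2 of m
      m:    [batch, m_rows, p_cols]
    Returns out[s, i, j] = m[s, dim0[s, i, 0], dim1[s, j, 0]], shape [batch, n0, n1].
    """
    rows = dim0[..., 0]                            # [batch, n0]
    cols = dim1[..., 0]                            # [batch, n1]
    batch = np.arange(m.shape[0])[:, None, None]   # [batch, 1, 1]
    # The three index arrays broadcast together to [batch, n0, n1].
    return m[batch, rows[:, :, None], cols[:, None, :]]


m = np.arange(6.0).reshape(1, 2, 3)  # [[0, 1, 2], [3, 4, 5]]
dim0 = np.array([[[1], [0]]])        # rows 1 and 0
dim1 = np.array([[[2]]])             # column 2

out = get_param_sparse_reference(dim0, dim1, m)
print(out.tolist())  # [[[5.0], [2.0]]]

# Cross-check against the dense form: expand the indices to one-hot matrices.
dense0 = np.eye(m.shape[1])[dim0[..., 0]]  # [batch, n0, m_rows]
dense1 = np.eye(m.shape[2])[dim1[..., 0]]  # [batch, n1, p_cols]
print(np.allclose(out, dense0 @ m @ np.swapaxes(dense1, 1, 2)))  # True
```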

carabiner.tf.utils.sparse_matmul(a: Tensor, b: Tensor) → Tensor

Matrix-multiply an indicator matrix tensor with a dense tensor.

The indicator tensor is a [batch x n x 1] tensor of indices; each index gives the position of the single value set to 1 in the corresponding row of the implied one-hot matrix.

Parameters:
  • a (Tensor) – Indicator matrix tensor.

  • b (Tensor) – Dense tensor.

Return type:

Tensor
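Because each row of the implied one-hot matrix has a single 1, multiplying it by b selects exactly one row of b. A minimal NumPy sketch of that gather, assuming b has shape [batch x m x k] and the result has shape [batch x n x k]; `sparse_matmul_reference` is a hypothetical name, not the carabiner implementation.

```python
import numpy as np


def sparse_matmul_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for sparse_matmul (not the carabiner code).

    Assumed shapes:
      a: [batch, n, 1] integer indices; row i of the implied one-hot matrix
         has its single 1 in column a[s, i, 0]
      b: [batch, m, k] dense
    Returns out[s, i, :] = b[s, a[s, i, 0], :], shape [batch, n, k].
    """
    # a already has b's number of dimensions; its trailing size-1 axis
    # broadcasts across the k columns, so each index selects a full row of b.
    return np.take_along_axis(b, a, axis=1)


dense = np.arange(12.0).reshape(1, 3, 4)  # rows [0..3], [4..7], [8..11]
idx = np.array([[[2], [0]]])              # select rows 2 and 0

out = sparse_matmul_reference(idx, dense)
print(out.tolist())  # [[[8.0, 9.0, 10.0, 11.0], [0.0, 1.0, 2.0, 3.0]]]
# Equivalent dense product with an explicit one-hot matrix.
print(np.allclose(out, np.eye(3)[idx[..., 0]] @ dense))  # True
```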

carabiner.tf.utils.sparse_matmul_t(a: Tensor, b: Tensor) → Tensor

Matrix-multiply an indicator matrix tensor with the transpose of a dense tensor.

The indicator tensor is a [batch x n x 1] tensor of indices; each index gives the position of the single value set to 1 in the corresponding row of the implied one-hot matrix.

This is intended to be more efficient than explicitly transposing the dense tensor and then multiplying.

Parameters:
  • a (Tensor) – Indicator matrix tensor.

  • b (Tensor) – Dense tensor whose transpose is multiplied.

Return type:

Tensor
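Likewise, a one-hot row times the transpose of b selects one column of b, so the product can be read directly out of b without building the transpose. A minimal NumPy sketch, assuming b has shape [batch x k x m] and the result has shape [batch x n x k]; `sparse_matmul_t_reference` is hypothetical, and the explicit transpose appears only in the check.

```python
import numpy as np


def sparse_matmul_t_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for sparse_matmul_t (not the carabiner code).

    Assumed shapes:
      a: [batch, n, 1] integer indices into the columns of b
      b: [batch, k, m] dense; the product is with its transpose [batch, m, k]
    Returns out[s, i, c] = b[s, c, a[s, i, 0]], shape [batch, n, k].
    """
    batch = np.arange(b.shape[0])[:, None, None]  # [batch, 1, 1]
    rows = np.arange(b.shape[1])[None, None, :]   # [1, 1, k]
    # Broadcasting batch, rows and a gives [batch, n, k]; each output row is
    # one column of b, gathered without forming the transpose of b.
    return b[batch, rows, a]


dense = np.arange(12.0).reshape(1, 3, 4)  # column j is [j, j + 4, j + 8]
idx = np.array([[[3], [1]]])              # select columns 3 and 1

out = sparse_matmul_t_reference(idx, dense)
print(out.tolist())  # [[[3.0, 7.0, 11.0], [1.0, 5.0, 9.0]]]
# The explicit transpose appears only in this check.
print(np.allclose(out, np.eye(4)[idx[..., 0]] @ np.swapaxes(dense, 1, 2)))  # True
```

Gathering b[s, :, j] directly is one plausible reason the documentation expects sparse_matmul_t to be cheaper than transposing b first; this sketch does not measure performance.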

Module contents