carabiner.tf package
Submodules
carabiner.tf.utils module
- carabiner.tf.utils.get_param(dim0: Tensor, dim1: Tensor, m: Tensor) → Tensor
Matrix-multiply two dense indicator matrix tensors with a batch of 2D dense matrices.
The dense tensor m has shape [batch x m x p]. The first indicator matrix, dim0, multiplies along the m axis, and the second, dim1, multiplies along the p axis.
- Parameters:
dim0 (Tensor) – Indicator matrix tensor.
dim1 (Tensor) – Indicator matrix tensor.
m (Tensor) – Dense tensor.
- Return type:
Tensor
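As an illustration only, the following NumPy sketch (not carabiner's TensorFlow code) shows the product this entry describes. It assumes dim0 has shape [batch x n x m], dim1 has shape [batch x k x p], and that the result is dim0 @ m @ transpose(dim1) with shape [batch x n x k]; the function name get_param_np is hypothetical. With one-hot indicator rows, each output entry picks out a single element of m.

```python
import numpy as np


def get_param_np(dim0: np.ndarray, dim1: np.ndarray, m: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy analogue of get_param (assumed semantics).

    dim0: [batch x n x m] indicator matrices (one-hot rows).
    dim1: [batch x k x p] indicator matrices (one-hot rows).
    m:    [batch x m x p] dense matrices.
    Returns dim0 @ m @ transpose(dim1), shape [batch x n x k].
    """
    # dim0 contracts the m axis of the dense tensor.
    left = np.matmul(dim0, m)  # [batch x n x p]
    # dim1 contracts the p axis, so it enters transposed.
    return np.matmul(left, np.swapaxes(dim1, -1, -2))  # [batch x n x k]


# Usage: build one-hot indicators from row/column indices.
rng = np.random.default_rng(0)
m = rng.normal(size=(2, 4, 5))             # batch=2, m=4, p=5
rows = np.array([[0, 3], [2, 2]])          # [batch x n], n=2
cols = np.array([[4, 1, 0], [3, 3, 2]])    # [batch x k], k=3
dim0 = np.eye(4)[rows]                     # one-hot, [2 x 2 x 4]
dim1 = np.eye(5)[cols]                     # one-hot, [2 x 3 x 5]
out = get_param_np(dim0, dim1, m)          # [2 x 2 x 3]
# Each entry selects one element: out[b, i, j] == m[b, rows[b, i], cols[b, j]]
```

Because the indicator rows are one-hot, the two matrix products only ever select entries; the sparse variant below exploits exactly that.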
- carabiner.tf.utils.get_param_sparse(dim0: Tensor, dim1: Tensor, m: Tensor) → Tensor
Matrix-multiply two indicator matrix tensors with a batch of 2D dense matrices.
Each indicator tensor is a [batch x n x 1] tensor of indices; each index gives the position of the single entry set to 1 in the corresponding row of the implied indicator matrix.
The dense tensor m has shape [batch x m x p]. The first indicator, dim0, indexes into the m axis, and the second, dim1, indexes into the p axis.
- Parameters:
dim0 (Tensor) – Indicator matrix tensor.
dim1 (Tensor) – Indicator matrix tensor.
m (Tensor) – Dense tensor.
- Return type:
Tensor
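Assuming get_param_sparse returns the same [batch x n x k] result as get_param once the index tensors are expanded to one-hot matrices, the whole product reduces to indexing with no multiplications. The NumPy sketch below illustrates that equivalence; get_param_sparse_np is a hypothetical name, and this is not the library's implementation.

```python
import numpy as np


def get_param_sparse_np(dim0: np.ndarray, dim1: np.ndarray, m: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy analogue of get_param_sparse (assumed semantics).

    dim0: [batch x n x 1] integer indices into the m axis.
    dim1: [batch x k x 1] integer indices into the p axis.
    m:    [batch x m x p] dense matrices.
    Returns out[b, i, j] = m[b, dim0[b, i, 0], dim1[b, j, 0]], shape [batch x n x k].
    """
    batch_idx = np.arange(m.shape[0])[:, None, None]  # [batch x 1 x 1]
    row_idx = dim0                                    # [batch x n x 1]
    col_idx = np.swapaxes(dim1, 1, 2)                 # [batch x 1 x k]
    # Advanced indexing broadcasts the three index arrays to [batch x n x k].
    return m[batch_idx, row_idx, col_idx]


# Usage: compare against the dense product with explicit one-hot matrices.
rng = np.random.default_rng(1)
m = rng.normal(size=(3, 4, 5))                 # batch=3, m=4, p=5
dim0 = rng.integers(0, 4, size=(3, 2, 1))      # n=2
dim1 = rng.integers(0, 5, size=(3, 6, 1))      # k=6
sparse = get_param_sparse_np(dim0, dim1, m)
dense = np.eye(4)[dim0[..., 0]] @ m @ np.swapaxes(np.eye(5)[dim1[..., 0]], -1, -2)
```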
- carabiner.tf.utils.sparse_matmul(a: Tensor, b: Tensor) → Tensor
Matrix-multiply an indicator matrix tensor with a dense tensor.
The indicator tensor is a [batch x n x 1] tensor of indices; each index gives the position of the single entry set to 1 in the corresponding row of the implied indicator matrix.
- Parameters:
a (Tensor) – Indicator matrix tensor.
b (Tensor) – Dense tensor.
- Return type:
Tensor
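Assuming a holds integer indices of shape [batch x n x 1] into the rows of b (shape [batch x k x p]), multiplying by the implied one-hot matrix is the same as gathering rows of b. The NumPy sketch below illustrates this; sparse_matmul_np is a hypothetical name, and the library itself operates on TensorFlow tensors.

```python
import numpy as np


def sparse_matmul_np(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy analogue of sparse_matmul (assumed semantics).

    a: [batch x n x 1] integer indices; row i of the implied one-hot
       matrix has its 1 in column a[batch, i, 0].
    b: [batch x k x p] dense matrices.
    Returns the equivalent of onehot(a) @ b, shape [batch x n x p].
    """
    batch_idx = np.arange(b.shape[0])[:, None]  # [batch x 1]
    # Selecting row a[bb, i, 0] of b[bb] for every i is the whole product.
    return b[batch_idx, a[..., 0]]              # [batch x n x p]


# Usage: compare against an explicit one-hot matrix product.
rng = np.random.default_rng(2)
b = rng.normal(size=(2, 5, 3))              # batch=2, k=5, p=3
a = rng.integers(0, 5, size=(2, 4, 1))      # n=4
sparse = sparse_matmul_np(a, b)
dense = np.eye(5)[a[..., 0]] @ b
```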
- carabiner.tf.utils.sparse_matmul_t(a: Tensor, b: Tensor) → Tensor
Matrix-multiply an indicator matrix tensor with the transpose of a dense tensor.
The indicator tensor is a [batch x n x 1] tensor of indices; each index gives the position of the single entry set to 1 in the corresponding row of the implied indicator matrix.
This should be more efficient than explicitly transposing the dense tensor.
- Parameters:
a (Tensor) – Indicator matrix tensor.
b (Tensor) – Dense tensor to transpose.
- Return type:
Tensor
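For the transposed case, assuming b has shape [batch x p x k] and the result is onehot(a) @ transpose(b) with shape [batch x n x p], each output entry is out[bb, i, j] = b[bb, j, a[bb, i, 0]]. The NumPy sketch below (hypothetical name sparse_matmul_t_np; not the library's code) reads those entries directly from the last axis of b, so no transposed copy of b is built, which is one plausible reason such a function can outperform an explicit transpose.

```python
import numpy as np


def sparse_matmul_t_np(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy analogue of sparse_matmul_t (assumed semantics).

    a: [batch x n x 1] integer indices into the last axis of b.
    b: [batch x p x k] dense matrices (used transposed).
    Returns the equivalent of onehot(a) @ transpose(b), shape [batch x n x p].
    """
    batch_idx = np.arange(b.shape[0])[:, None, None]  # [batch x 1 x 1]
    row_idx = np.arange(b.shape[1])[None, None, :]    # [1 x 1 x p]
    # Broadcasting gives out[bb, i, j] = b[bb, j, a[bb, i, 0]].
    return b[batch_idx, row_idx, a]                   # [batch x n x p]


# Usage: compare against an explicit transpose followed by a one-hot product.
rng = np.random.default_rng(3)
b = rng.normal(size=(2, 3, 5))              # batch=2, p=3, k=5
a = rng.integers(0, 5, size=(2, 4, 1))      # n=4
sparse = sparse_matmul_t_np(a, b)
dense = np.eye(5)[a[..., 0]] @ np.swapaxes(b, -1, -2)
```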