
SSD Embedding Operators

CUDA Operators

Tensor masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count)

Similar to torch.Tensor.index_put, but ignores indices that are < 0.

masked_index_put_cuda only supports 2D values. It writes the first count rows of values into self at the row positions given by the entries of indices that are >= 0.

# Equivalent PyTorch Python code
indices = indices[:count]    # consider only the first `count` indices
filter_ = indices >= 0       # mask out negative (sentinel) indices
indices_ = indices[filter_]
self[indices_] = values[filter_.nonzero().flatten()]
Parameters:
  • self – The 2D output tensor (the tensor that is indexed)

  • indices – The 1D index tensor

  • values – The 2D input tensor

  • count – The tensor that contains the length of indices to process

Returns:

The self tensor
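The semantics above can be exercised on CPU with plain PyTorch indexing. The sketch below wraps the equivalent code in a helper (masked_index_put_ref is a hypothetical name used here for illustration; the actual operator requires CUDA tensors):

```python
import torch

def masked_index_put_ref(dst, indices, values, count):
    # CPU reference for the semantics of masked_index_put_cuda:
    # write the first `count` rows of `values` into `dst` at the
    # row positions given by the non-negative entries of `indices`.
    n = int(count.item())
    idx = indices[:n]
    keep = idx >= 0                      # negative indices are ignored
    dst[idx[keep]] = values[:n][keep]
    return dst

out = torch.zeros(4, 2)
indices = torch.tensor([2, -1, 0, 3])
values = torch.arange(8, dtype=torch.float32).reshape(4, 2)
count = torch.tensor(3)                  # only the first 3 indices are used
masked_index_put_ref(out, indices, values, count)
# Rows 0 and 2 of `values` land in rows 2 and 0 of `out`; the -1 entry is skipped.
```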

Tensor masked_index_select_cuda(Tensor self, Tensor indices, Tensor values, Tensor count)

Similar to torch.index_select, but ignores indices that are < 0.

masked_index_select_cuda only supports 2D values. It copies the rows of values selected by the non-negative entries of indices (considering only the first count entries) into the corresponding rows of self.

# Equivalent PyTorch Python code
indices = indices[:count]    # consider only the first `count` indices
filter_ = indices >= 0       # mask out negative (sentinel) indices
indices_ = indices[filter_]
self[filter_.nonzero().flatten()] = values[indices_]
Parameters:
  • self – The 2D output tensor

  • indices – The 1D index tensor

  • values – The 2D input tensor (the tensor that is indexed)

  • count – The tensor that contains the length of indices to process

Returns:

The self tensor
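As with the put variant, the gather semantics can be reproduced on CPU (a sketch only; masked_index_select_ref is a hypothetical name, and the actual operator requires CUDA tensors):

```python
import torch

def masked_index_select_ref(dst, indices, values, count):
    # CPU reference for the semantics of masked_index_select_cuda:
    # gather rows of `values` at the non-negative entries of `indices`
    # (first `count` entries only) into the corresponding rows of `dst`.
    n = int(count.item())
    idx = indices[:n]
    keep = idx >= 0                      # negative indices are ignored
    dst[keep.nonzero().flatten()] = values[idx[keep]]
    return dst

out = torch.zeros(4, 2)
indices = torch.tensor([2, -1, 0, 3])
values = torch.arange(8, dtype=torch.float32).reshape(4, 2)
count = torch.tensor(3)
masked_index_select_ref(out, indices, values, count)
# Rows 2 and 0 of `values` are gathered into rows 0 and 2 of `out`.
```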

std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(const Tensor &lxu_cache_locations, const Tensor &assigned_cache_slots, const Tensor &linear_index_inverse_indices, const Tensor &unique_indices_count_cumsum, const Tensor &cache_set_inverse_indices, const Tensor &lxu_cache_weights, const Tensor &inserted_ssd_weights, const Tensor &unique_indices_length, const Tensor &cache_set_sorted_unique_indices)

Generate memory addresses for SSD TBE data.

The data retrieved from SSD can be stored in either a scratch pad (HBM) or LXU cache (also HBM). lxu_cache_locations is used to specify the location of the data. If the location is -1, the data for the associated index is in the scratch pad; otherwise, it is in the cache. To enable TBE kernels to access the data conveniently, this operator generates memory addresses of the first byte for each index. When accessing data, a TBE kernel only needs to convert addresses into pointers.

Moreover, this operator also generates the list of post-backward evicted indices, that is, the indices whose data resides in the scratch pad.

Parameters:
  • lxu_cache_locations – The tensor that contains cache slots where data is stored for the full list of indices. -1 is a sentinel value that indicates that data is not in cache.

  • assigned_cache_slots – The tensor that contains cache slots for the unique list of indices. -1 indicates that data is not in cache

  • linear_index_inverse_indices – The tensor that contains the original position of linear indices before being sorted

  • unique_indices_count_cumsum – The tensor that contains the exclusive prefix sum results of the counts of unique indices

  • cache_set_inverse_indices – The tensor that contains the original positions of cache sets before being sorted

  • lxu_cache_weights – The LXU cache tensor

  • inserted_ssd_weights – The scratch pad tensor

  • unique_indices_length – The tensor that contains the number of unique indices (GPU tensor)

  • cache_set_sorted_unique_indices – The tensor that contains associated unique indices for the sorted unique cache sets

Returns:

A tuple of tensors (the SSD row address tensor and the post backward evicted index tensor)
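The core address-selection idea can be illustrated on CPU. This is a heavily simplified, hypothetical sketch (row_addrs_sketch and its scratch_pad_rows argument are invented for illustration): the real operator also resolves duplicate indices through the inverse-index and cumsum arguments, which are omitted here.

```python
import torch

def row_addrs_sketch(lxu_cache_locations, scratch_pad_rows,
                     lxu_cache_weights, inserted_ssd_weights):
    # For each index, emit the address of the first byte of its row:
    # from the cache if its location is >= 0, otherwise from the
    # scratch pad. Also collect the scratch-pad ("evicted") indices.
    cache_row_bytes = lxu_cache_weights.stride(0) * lxu_cache_weights.element_size()
    sp_row_bytes = inserted_ssd_weights.stride(0) * inserted_ssd_weights.element_size()
    addrs, evicted = [], []
    for i, loc in enumerate(lxu_cache_locations.tolist()):
        if loc >= 0:
            addrs.append(lxu_cache_weights.data_ptr() + loc * cache_row_bytes)
        else:
            addrs.append(inserted_ssd_weights.data_ptr()
                         + scratch_pad_rows[i] * sp_row_bytes)
            evicted.append(i)
    return torch.tensor(addrs), torch.tensor(evicted)

cache = torch.zeros(4, 3)
scratch_pad = torch.ones(2, 3)
locations = torch.tensor([1, -1])   # index 0 -> cache slot 1; index 1 -> scratch pad
addrs, evicted = row_addrs_sketch(locations, [-1, 0], cache, scratch_pad)
# addrs[0] points at cache row 1; addrs[1] points at scratch pad row 0.
```

A TBE kernel consuming such addresses only needs to reinterpret each integer as a pointer to the first byte of the row, regardless of which buffer it lives in.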
