Attention

Published 2023-08-24 11:17:23 · Author: 信海
#include <type_traits>
#include <limits>
#include <c10/core/DeviceType.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/cpu/vec/vec256/vec256.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <utility>
#include <c10/util/typeid.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/Logging.h>
#include <c10/util/Exception.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cat.h>
#endif

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
namespace at {

namespace native {

DEFINE_DISPATCH(_fused_sdp_choice_stub);
REGISTER_NO_CPU_DISPATCH(_fused_sdp_choice_stub);

namespace {

Tensor gemm_nt(const Tensor& self, const Tensor& other) {
  if (self.is_nested()) {
    return NestedTensor_matmul(self, other.t());
  } else {
    return at::native::matmul(self, other.t());
  }
}
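
// Illustrative sketch (not part of the original file, shapes assumed for
// clarity): for dense inputs gemm_nt(self, other) is just a matmul against the
// transposed weight. With B=2, T=5, D=8 and a packed [3*D, D] projection weight:
//
//   at::Tensor x = at::randn({2, 5, 8});     // [B, T, D]
//   at::Tensor w = at::randn({3 * 8, 8});    // [3*D, D]
//   at::Tensor y = at::matmul(x, w.t());     // [B, T, 3*D], same result as gemm_nt(x, w)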

template <typename scalar_t>
void transform_bias_rescale_qkv_inner_loop(
    int64_t B,
    int64_t T,
    int64_t _3D,
    int64_t D,
    int64_t num_head,
    int64_t dim_per_head,
    scalar_t* qkv_data,
    scalar_t* qkv_bias_data,
    scalar_t* q_k_v_data,
    scalar_t inv_sqrt_dim_per_head,
    int64_t begin,
    int64_t end) {
  for (auto i : c10::irange(begin, end)) {
    auto t = i % T;
    i /= T;
    auto nh = i % num_head;
    i /= num_head;
    auto b = i;
    using Vec = vec::Vectorized<scalar_t>;
    auto V = vec::Vectorized<scalar_t>::size();
    auto dh = 0;
    auto d = nh * dim_per_head;
    for (; dh + V <= dim_per_head; dh += V, d += V) {
      // load
      auto q_bias_data = Vec::loadu(&qkv_bias_data[d + 0 * D]);
      auto k_bias_data = Vec::loadu(&qkv_bias_data[d + 1 * D]);
      auto v_bias_data = Vec::loadu(&qkv_bias_data[d + 2 * D]);

      auto q_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 0 * D]) +
          q_bias_data;
      auto k_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 1 * D]) +
          k_bias_data;
      auto v_data = Vec::loadu(&qkv_data[b * _3D * T + t * _3D + d + 2 * D]) +
          v_bias_data;

      q_data = q_data * Vec(inv_sqrt_dim_per_head);

      q_data.store(&q_k_v_data
                       [0 * B * num_head * T * dim_per_head +
                        b * num_head * T * dim_per_head +
                        nh * T * dim_per_head + t * dim_per_head + dh]);
      k_data.store(&q_k_v_data
                       [1 * B * num_head * T * dim_per_head +
                        b * num_head * T * dim_per_head +
                        nh * T * dim_per_head + t * dim_per_head + dh]);
      v_data.store(&q_k_v_data
                       [2 * B * num_head * T * dim_per_head +
                        b * num_head * T * dim_per_head +
                        nh * T * dim_per_head + t * dim_per_head + dh]);
    }
    for (; dh < dim_per_head; dh++) {
      auto d = nh * dim_per_head + dh;
      auto q_bias = qkv_bias_data[d + 0 * D];
      auto k_bias = qkv_bias_data[d + 1 * D];
      auto v_bias = qkv_bias_data[d + 2 * D];
      auto q_data = qkv_data[b * _3D * T + t * _3D + d + 0 * D] + q_bias;
      auto k_data = qkv_data[b * _3D * T + t * _3D + d + 1 * D] + k_bias;
      auto v_data = qkv_data[b * _3D * T + t * _3D + d + 2 * D] + v_bias;
      q_data = q_data * inv_sqrt_dim_per_head;
      q_k_v_data
          [0 * B * num_head * T * dim_per_head +
           b * num_head * T * dim_per_head + nh * T * dim_per_head +
           t * dim_per_head + dh] = q_data;
      q_k_v_data
          [1 * B * num_head * T * dim_per_head +
           b * num_head * T * dim_per_head + nh * T * dim_per_head +
           t * dim_per_head + dh] = k_data;
      q_k_v_data
          [2 * B * num_head * T * dim_per_head +
           b * num_head * T * dim_per_head + nh * T * dim_per_head +
           t * dim_per_head + dh] = v_data;
    }
  }
}
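
// Index-decoding sketch (added for clarity, not in the original): the flat
// parallel index i enumerates (b, nh, t) in row-major order over
// [B, num_head, T]. For example, with B=2, num_head=3, T=4, the index i=17
// decodes to t = 17 % 4 = 1, nh = (17 / 4) % 3 = 1, b = 17 / (4 * 3) = 1.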

Tensor transform_0213(const Tensor& a) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(1));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(a.size(3));
  return a.permute({0, 2, 1, 3})
      .contiguous()
      .view({a.size(0), a.size(2), a.size(1) * a.size(3)});
}
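
// Shape sketch (example values assumed): transform_0213 folds the per-head
// layout back into a flat embedding dimension:
//
//   at::Tensor a = at::randn({2, 4, 5, 8});              // [B, num_head, T, dim_per_head]
//   at::Tensor b = a.permute({0, 2, 1, 3}).contiguous()
//                   .view({2, 5, 4 * 8});                // [B, T, num_head * dim_per_head]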

} // namespace


Tensor bmm_nt(const Tensor& a, const Tensor& b) {
  // a,b: B * H * T * Embed
  auto a_ = a.view({a.size(0) * a.size(1), a.size(2), a.size(3)});  // reshape to (B*H) * T * Embed
  auto b_ = b.view({b.size(0) * b.size(1), b.size(2), b.size(3)});  // reshape to (B*H) * T * Embed
  auto bt_ = b_.transpose(2, 1); // (B*H) * Embed * T
  auto c_ = at::bmm(a_, bt_); // (B*H) * T * T
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(2)});  // B * H * T * T
}
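
// Usage sketch (assumed shapes): with q and k in the [B, num_head, T, dim_per_head]
// layout produced by transform_bias_rescale_qkv, bmm_nt gives the raw attention
// scores q @ k^T for every head:
//
//   at::Tensor q = at::randn({2, 4, 5, 8});
//   at::Tensor k = at::randn({2, 4, 5, 8});
//   at::Tensor scores = bmm_nt(q, k);   // [2, 4, 5, 5] == [B, num_head, T, T]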

Tensor masked_softmax(
    Tensor& attn_scores,
    c10::optional<Tensor> attn_mask,
    const Tensor& query,
    c10::optional<int64_t> mask_type) {
  if (query.is_nested() && !attn_mask) {
    return at::_nested_tensor_softmax_with_shape(attn_scores, query);
  }
  if (attn_mask && attn_mask->dtype() != at::kBool) {
    attn_mask = attn_mask->to(at::kBool);
  }
  if (attn_mask) {
    return _masked_softmax(attn_scores, *attn_mask, attn_scores.dim() - 1, mask_type);
  } else {
    return _softmax_out(attn_scores, attn_scores, attn_scores.dim() - 1, false);
  }
}

Tensor bmm_nn(Tensor& out, const Tensor& a, const Tensor& b) {
  const std::array<int64_t, 3> newAShape = {
      a.sizes()[0] * a.sizes()[1], a.sizes()[2], a.sizes()[3]};
  auto a_ = a.view(newAShape);
  const std::array<int64_t, 3> newBShape = {
      b.sizes()[0] * b.sizes()[1], b.sizes()[2], b.sizes()[3]};
  auto b_ = b.view(newBShape);
  auto out_ = out.reshape({newAShape[0], newAShape[1], newBShape[2]});
  auto c_ = at::bmm_out(out_, a_, b_);
  return c_.view({a.size(0), a.size(1), a.size(2), b.size(3)});
}
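
// Note (added): unlike bmm_nt above, bmm_nn takes an explicit `out` tensor and
// routes through at::bmm_out, so the caller can recycle existing storage
// (native_multi_head_attention_cpu below reuses q's buffer) instead of
// allocating a fresh result tensor.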


Tensor transform0213_gemm_nt_bias(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& query) {
  if (query.is_nested()) {
    at::Tensor nested_a = _nested_from_padded(
        a, get_nested_tensor_impl(query)->get_nested_sizes(), true);
    return NestedTensor_times_Tensor_plus_Tensor_addmm(
        c, nested_a, b.t(), 1, 1);
  } else {
    const Tensor a_0213 = transform_0213(a);
    auto a_ = a_0213.view({a_0213.size(0) * a_0213.size(1), a_0213.size(2)});
    auto r_ = at::native::linear(a_, b, c);
    return r_.view({a_0213.size(0), a_0213.size(1), r_.size(1)});
  }
}
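
// Dense-path sketch (assumed equivalence, illustrative only; B, T, D stand for
// the usual batch, sequence and embedding sizes): for a non-nested query,
// transform0213_gemm_nt_bias(a, b, c, query) amounts to
//
//   // a: [B, num_head, T, dim_per_head], b: [D, D] weight, c: [D] bias
//   auto flat = transform_0213(a).view({B * T, D});      // [B*T, D]
//   auto out  = at::linear(flat, b, c).view({B, T, D});  // output projection + bias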

void debug_assert_shape(int line, const Tensor& t, c10::IntArrayRef shape) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      (size_t)t.dim() == shape.size(),
      "(called from line ",
      line,
      ") ",
      "expected ",
      shape.size(),
      "-D tensor but got ",
      t.dim());
  if (t.is_nested()) {
    return;
  }
  for (auto idx : c10::irange(shape.size())) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        shape[idx] == 0 || t.sizes()[idx] == shape[idx],
        "(called from line ",
        line,
        ") ",
        "expected dim ",
        idx,
        " to be ",
        shape[idx],
        " but got ",
        t.sizes()[idx]);
  }
}

Tensor qkv_projection(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const Tensor& qkv_weight) {
  // shape: [B, T, 3 x D]
  Tensor qkv;

  if (key.is_same(value)) {
    if (query.is_same(key)) {
      // self-attention
      qkv = gemm_nt(query, qkv_weight);
    } else {
      // encoder-decoder attention
      // TODO: is there a more efficient way to set this up?
      // TODO: can we stay nested instead of using cat? Probably just make a
      // NestedTensor out of the matmul results or something?
      auto q_kv_weight_s =
          at::native::split_with_sizes(qkv_weight, {embed_dim, embed_dim * 2}, 0);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          q_kv_weight_s.size() == 2,
          "expected split to produce 2 tensors but it produced ",
          q_kv_weight_s.size());
      auto q = gemm_nt(query, q_kv_weight_s[0]);
      auto kv = gemm_nt(key, q_kv_weight_s[1]);
      qkv = at::cat({std::move(q), std::move(kv)}, 2);
    }
  } else {
    auto q_k_v_weight_s = at::native::chunk(qkv_weight, 3, 0);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        q_k_v_weight_s.size() == 3,
        "expected chunk to produce 3 tensors but it produced ",
        q_k_v_weight_s.size());
    // TODO: can we stay nested instead of using cat?
    auto q = gemm_nt(query, q_k_v_weight_s[0]);
    auto k = gemm_nt(key, q_k_v_weight_s[1]);
    auto v = gemm_nt(value, q_k_v_weight_s[2]);
    qkv = at::cat({std::move(q), std::move(k), std::move(v)}, 2);
  }

  return qkv;
}
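
// Weight-packing sketch (assumption, shapes for illustration): in the
// self-attention case a single GEMM against the packed [3*D, D] weight yields
// q, k and v side by side; in the encoder-decoder case the weight is split and
// the partial results are concatenated back along the last dim:
//
//   at::Tensor query = at::randn({2, 5, 8});                  // [B, T, D]
//   at::Tensor qkv_w = at::randn({3 * 8, 8});                 // [3*D, D]
//   auto parts = qkv_w.split_with_sizes({8, 2 * 8}, 0);       // [D, D] and [2*D, D]
//   auto q   = at::matmul(query, parts[0].t());               // [B, T, D]
//   auto kv  = at::matmul(query, parts[1].t());               // [B, T, 2*D]
//   auto qkv = at::cat({q, kv}, 2);                           // [B, T, 3*D]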

// compute q = (q + q_bias) / sqrt(dim_per_head), k = k + k_bias, v = v + v_bias
std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cpu(
    const Tensor& qkv,
    const Tensor& qkv_bias,
    const int64_t num_head) {
  auto qkv_ = qkv.is_nested()
    ? c10::MaybeOwned<Tensor>::owned(qkv.to_padded_tensor(0))
    : c10::MaybeOwned<Tensor>::borrowed(qkv);
  auto B = qkv_->size(0);
  auto T = qkv_->size(1);
  auto _3D = qkv_->size(2);
  auto D = _3D / 3;
  TORCH_CHECK(D % num_head == 0);
  TORCH_CHECK(_3D % 3 == 0);
  const auto dim_per_head = D / num_head;
  auto q_k_v = at::empty({3, B, num_head, T, dim_per_head}, qkv_->options());

  const auto qkv_contig = qkv_->expect_contiguous();
  const auto qkv_bias_contig = qkv_bias.expect_contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      qkv_->scalar_type(),
      "transform_bias_rescale_qkv",
      [&] {
        scalar_t* qkv_data = qkv_contig->data_ptr<scalar_t>();
        scalar_t* qkv_bias_data = qkv_bias_contig->data_ptr<scalar_t>();
        scalar_t* q_k_v_data = q_k_v.data_ptr<scalar_t>();
        const scalar_t inv_sqrt_dim_per_head =
            1.0 / std::sqrt(static_cast<scalar_t>(dim_per_head));

        int64_t grain_size =
            std::max(internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
        parallel_for(
            0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
              transform_bias_rescale_qkv_inner_loop(
                  B,
                  T,
                  _3D,
                  D,
                  num_head,
                  dim_per_head,
                  qkv_data,
                  qkv_bias_data,
                  q_k_v_data,
                  inv_sqrt_dim_per_head,
                  begin,
                  end);
            });
      });
  auto q_k_v_s =
      at::native::split(q_k_v.view({3 * B, num_head, T, dim_per_head}), B, 0);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(q_k_v_s.size() == 3);
  return std::make_tuple(q_k_v_s[0], q_k_v_s[1], q_k_v_s[2]);
}
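
// Tensor-level equivalent (assumed sketch; the kernel above fuses the bias add,
// the q rescale and the transpose into a single parallel pass over memory):
//
//   // qkv: [B, T, 3*D], qkv_bias: [3*D]
//   auto x = (qkv + qkv_bias).view({B, T, 3, num_head, dim_per_head});
//   auto q_k_v = x.permute({2, 0, 3, 1, 4}).contiguous();     // [3, B, num_head, T, dim_per_head]
//   q_k_v[0] /= std::sqrt(static_cast<double>(dim_per_head)); // only q is rescaled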

std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const int64_t num_head,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const c10::optional<Tensor>& mask,
    bool need_weights,
    bool average_attn_weights,
    const c10::optional<int64_t> mask_type) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]

  TORCH_CHECK(
      !mask || !query.is_nested(),
      "NestedTensor with mask is not supported yet");
  const auto D = embed_dim;
  TORCH_CHECK(
      query.dim() == 3,
      "expected 3-D `query`, got ",
      query.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.is_nested() || query.sizes()[2] == embed_dim,
      "passed-in embed_dim ",
      embed_dim,
      " didn't match last dim of query ",
      query.sizes()[2]);
  TORCH_CHECK(
      key.dim() == 3,
      "expected 3-D `key`, got ",
      key.dim(),
      "-D tensor");
  TORCH_CHECK(
      value.dim() == 3,
      "expected 3-D `value`, got ",
      value.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.is_nested() || key.is_nested() || value.is_nested() ||
          (query.sizes() == key.sizes() && key.sizes() == value.sizes()),
      "expected `query`/`key`/`value` shapes to match");
  TORCH_CHECK(
      qkv_weight.dim() == 2,
      "expected 2-D `qkv_weight`, got ",
      qkv_weight.dim(),
      "-D tensor");
  TORCH_CHECK(
      D * 3 == qkv_weight.sizes()[0],
      "expected `qkv_weight` first dim to be 3x embed_dim");
  TORCH_CHECK(
      D == qkv_weight.sizes()[1],
      "expected `qkv_weight` second dim to be embed_Dim");
  TORCH_CHECK(
      qkv_bias.dim() == 1,
      "expected 1-D `qkv_bias`, got ",
      qkv_bias.dim(),
      "-D tensor");
  TORCH_CHECK(
      qkv_bias.sizes()[0] == 3 * D,
      "expected `qkv_bias` first dim and first dim of query to be equal");
  TORCH_CHECK(D % num_head == 0, "`embed_dim` must divide evenly by `num_heads`");

#ifndef NDEBUG
  const auto B = query.is_nested()
      ? get_nested_tensor_impl(query)->get_nested_sizes().size(0)
      : query.sizes()[0];
  auto T = query.is_nested() ? 0 : query.sizes()[1];
  const auto dim_per_head = D / num_head;
#endif

  // shape: [B, T, 3 x D]
  auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight);  // core step: compute the packed QKV projection


  if (!qkv.is_nested() && qkv.numel() == 0) {  // qkv is dense (not nested) and has no elements
    // early return for an empty qkv
    if (query.is_nested()) {   // is query a nested tensor?
      return std::make_tuple(Tensor(), Tensor()); // if so, return a tuple of two empty Tensors
    }
    return std::make_tuple(at::empty_like(query), Tensor()); // otherwise return {an empty tensor shaped like query, an empty Tensor}
  }

#ifndef NDEBUG
  if (!query.is_nested() || !qkv.is_nested()) {   // debug-only shape check for the dense path
    if (query.is_nested()) {  // if query is nested, take T from qkv's second dimension
      T = qkv.size(1);
    }
    debug_assert_shape(__LINE__, qkv, {B, T, 3 * D});
  }
#endif

#ifdef DEBUG_PRINT_EACH_STEP
  if (!qkv.is_nested()) {  // print qkv when it is not nested
    std::cerr << "qkv: " << qkv << std::endl;
  }
#endif
  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);  // add the bias, rescale q, and split into the per-head layout
  qkv = Tensor(); // Not used any more, allow free
  auto& q = std::get<0>(q_k_v);  // 读取q
  const auto& k = std::get<1>(q_k_v);  // 读取k
  const auto& v = std::get<2>(q_k_v);  // 读取v
#ifndef NDEBUG
  debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "q: " << q << std::endl;
  std::cerr << "k: " << k << std::endl;
  std::cerr << "v: " << v << std::endl;
#endif

  // shape: [B, num_head, T, T]
  auto qkt = bmm_nt(q, k);  // raw attention scores: q @ k^T per head
  // q & k are dead but cannot be freed because they were packed with v

#ifndef NDEBUG
  debug_assert_shape(__LINE__, qkt, {B, num_head, T, T});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, T]
  // TODO: long-term, have a kernel that works with
  // NestedTensor directly if there is no mask passed
  qkt = masked_softmax(qkt, mask, query, mask_type);  // apply the mask (if any) and softmax over the last dim

#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "qkt after softmax: " << qkt << std::endl;
#endif

  // shape: [B, num_head, T, dim_per_head]
  // reuse storage for q; we're done with it
  auto attn_ctx = bmm_nn(q, qkt, v);  // unlike bmm_nt, this writes the result into q's storage, which is reused as the output buffer
  // qkv is not dead; we just reused storage for q!
  if (!need_weights) {
    qkt = Tensor();
  }
#ifndef NDEBUG
  debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif

  // shape: [B, T, D]
  // Fuse transform_0213 inside
  auto proj = transform0213_gemm_nt_bias(
      attn_ctx, proj_weight, proj_bias, query);  // back to [B, T, D]: concatenate the heads and apply the output projection
#ifndef NDEBUG
  debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
  if (need_weights && average_attn_weights) {
    // weights are not needed for full transformer, so don't worry too
    // much about performance -- we implement this just so that use
    // cases that don't disable need_weights still get some speedup.
    qkt = qkt.sum(1);
    qkt /= num_head;
  }
  return std::make_tuple(std::move(proj), std::move(qkt));
}
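
// Dense-path shape flow of native_multi_head_attention_cpu (summary, added):
//   query [B, T, D]
//     --qkv_projection-->             qkv      [B, T, 3*D]
//     --_transform_bias_rescale_qkv-> q, k, v  [B, num_head, T, dim_per_head]
//     --bmm_nt + masked_softmax-->    attn     [B, num_head, T, T]
//     --bmm_nn-->                     attn_ctx [B, num_head, T, dim_per_head]
//     --transform0213_gemm_nt_bias--> proj     [B, T, D]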

int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, c10::optional<double> scale){
  return static_cast<int64_t>(sdp::SDPBackend::math);
}

int64_t _fused_sdp_choice_meta(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const c10::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    c10::optional<double> scale) {
  auto query_key_set = query_.key_set();
  bool has_cuda = query_key_set.has(c10::DispatchKey::CUDA);
  if (has_cuda) {
    auto choice_int = _fused_sdp_choice_stub(
        at::kCUDA,
        query_,
        key,
        value,
        attn_mask_,
        dropout_p,
        is_causal,
        scale);
    return choice_int;
  }
  return static_cast<int64_t>(sdp::SDPBackend::math);
}
namespace {

inline void validate_sdpa_input(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const c10::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    c10::optional<double> scale) {
  TORCH_CHECK(
      query_.dtype() == key.dtype() && query_.dtype() == value.dtype(),
      "Expected query, key, and value to have the same dtype, but got query.dtype: ",
      query_.dtype(), " key.dtype: ", key.dtype(), " and value.dtype: ", value.dtype(), " instead.");
  TORCH_CHECK(
      query_.device() == key.device() && query_.device() == value.device(),
      "Expected query, key, and value to have the same device type, but got query.device: ",
      query_.device(), " key.device: ", key.device(), " and value.device: ", value.device(), " instead.");
  TORCH_CHECK(
      query_.dim() >= 2 && key.dim() >= 2 && value.dim() >= 2,
      "Expected query, key, and value to all be  at least 2 dimensional, but got query.dim: ",
      query_.dim(), " key.dim: ", key.dim(), " and value.dim: ", value.dim(), " instead.");
  if (attn_mask_.has_value()){
    auto mask_dtype = attn_mask_->dtype();
    TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == query_.dtype(),
      "Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: ",
      mask_dtype, " and  query.dtype: ", query_.dtype(), " instead.");
    TORCH_CHECK(
      !query_.is_nested() && !key.is_nested(),
      "Scaled_dot_product_attention: Nested tensors for query / key are not supported "
      "when an explicit attn_mask is set");
  }
  return;
}
// This function is used to produce an attn_mask
// in a standard format that can be consumed by both
// the math and memory efficient attn_mask implementation
//  Args:
//    attn_mask: attn_mask of shape (B, L, S) or (L, S) or (B, N_heads, L, S)
c10::optional<Tensor> convert_boolean_attn_mask(const c10::optional<Tensor>& attn_mask, caffe2::TypeMeta dtype) {
  // Pass through
  if(!attn_mask.has_value()){
    return c10::nullopt;
  }
  // Convert boolean mask to additive mask; need to invert mask to indicate what
  // to mask *out*.
  if (attn_mask->dtype() == at::kBool) {
    auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype);
    // TODO Use the max type of the input and output
    new_attn_mask.masked_fill_(
        attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
    return new_attn_mask;
  }
  // Otherwise, attn_mask represents an additive attention tensor
  return attn_mask;
}
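
// Worked example (values chosen for illustration): a boolean mask where `true`
// means "take part in attention" becomes an additive bias with -inf at the
// masked-out positions, which the later softmax turns into zero probability:
//
//   // mask = [[true, false],        ->  bias = [[0.0, -inf],
//   //         [true, true ]]                    [0.0,  0.0]]
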
// Memory Efficient Attention requires a padded attn mask bias
// This function pads the attn_mask bias to be a multiple of 16
// Then slices the padded bias to the original size
// We apply this function to the top level SDPA so that
// if padding is done it will be tracked for backward automatically

template <int alignment>
bool is_aligned(const SymInt& size){
  return size % alignment == 0;
}

template <int alignment>
at::Tensor pad_bias(const at::Tensor& attn_bias) {
  auto last_dim_size = attn_bias.sym_size(-1);
  auto pad_count = alignment - (last_dim_size % alignment);
  auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count});
  return padded_bias.slice_symint(-1, 0, last_dim_size);
}
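
// Worked example (added): with alignment 16 and a last dim of 40,
// pad_count = 16 - (40 % 16) = 8, so the bias is padded to 48 columns and then
// sliced back to the original 40. The slice is a view over the padded storage,
// so the row stride stays a multiple of 16 as the memory-efficient kernel
// expects, while autograd still tracks the original shape.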

at::Tensor preprocess_mask(
    const at::Tensor& mask,
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value) {
  constexpr int mem_eff_alignment = 16;
  // Expand to 4d case
  at::Tensor attn_mask = mask.expand_symint(
      {query.sym_size(0),
       query.sym_size(1),
       query.sym_size(2),
       key.sym_size(2)});

  bool aligned_last_dim = is_aligned<mem_eff_alignment>(attn_mask.sym_size(-1));
  // Apply pad_bias and store the result in attn_mask
  if (!aligned_last_dim) {
    return pad_bias<mem_eff_alignment>(attn_mask);
  }
  // Check and make the tensor contiguous if needed
  if (attn_mask.sym_stride(0) % 16 != 0 || attn_mask.sym_stride(1) % 16 != 0 ||
      attn_mask.sym_stride(2) % 16 != 0) {
    return attn_mask.contiguous();
  }

  return attn_mask;
}

} // namespace

// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
// greater than 0.0 is specified.
//
// Args:
//     query (Tensor): Query tensor; shape (N, ..., L, E)
//     key (Tensor): Key tensor; shape (N, ..., S, E)
//     value (Tensor): Value tensor; shape (N, ..., S, E)
//     attn_mask (optional Tensor): Attention mask; shape (N, ..., L, S) or (L, S). Currently, only a boolean mask
//         is supported, where a value of True indicates that the element *should* take part in attention.
//     dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
//     need_attn_weights (bool): If true, the second return value will contain the attention weights used;
//         otherwise, the second return value is unspecified
//     is_causal (bool): If true, assumes causal attention masking; for this case, attn_mask should not be set.
//         TODO: Consider removing this flag before promoting this function to the public API. It's possible
//         to get specialized support for causal masks (and other types of masking e.g. local attention / block
//         sparse masks) via tensor subclassing, allowing for a leaner API.
//
// Returns a tuple containing:
//     output (Tensor): Attention output; shape (N, ..., L, E)
//     attn_weights (Tensor): Attention weighting; shape (N, ..., L, S)
//
// Shape legend:
//     N: Batch size
//     ...: Any number of other batch dimensions (optional)
//     S: Source sequence length
//     L: Target sequence length
//     E: Embedding dimension
Tensor scaled_dot_product_attention(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const c10::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal,
    c10::optional<double> scale) {
  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
  int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
  if (query_.device().type() == DeviceType::CUDA){
    choice_int = _fused_sdp_choice_stub(query_.device().type(),
      query_, key, value, attn_mask_, dropout_p, is_causal, scale);
  }
  sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
  c10::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
  switch (backend) {
    case sdp::SDPBackend::flash_attention: {
      auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
          query_, key, value, dropout_p, is_causal, false /*return_debug_mask*/, scale);
      return std::get<0>(out_lse_softmax);
    }
    case sdp::SDPBackend::efficient_attention: {
      bool compute_logsumexp =
          (query_.requires_grad() || key.requires_grad() ||
           value.requires_grad());
      if (attn_mask.has_value()) {
        attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);
      }
      auto out_and_lse = at::_scaled_dot_product_efficient_attention(
          query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
      return std::get<0>(out_and_lse);
    }
    case sdp::SDPBackend::math:
      return std::get<0>(at::_scaled_dot_product_attention_math(
          query_,
          key,
          value,
          attn_mask,
          dropout_p,
          is_causal,
          c10::nullopt, /*dropout_mask*/
          scale));
    default:
      TORCH_CHECK(
          false,
          "No viable backend for scaled_dot_product_attention was found.");
      return Tensor();
  }
}
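
// Usage sketch (assumed shapes; the backend is chosen at runtime as above):
//
//   at::Tensor q = at::randn({2, 8, 16, 64});   // [N, num_heads, L, E]
//   at::Tensor k = at::randn({2, 8, 32, 64});   // [N, num_heads, S, E]
//   at::Tensor v = at::randn({2, 8, 32, 64});   // [N, num_heads, S, E]
//   at::Tensor out = at::scaled_dot_product_attention(
//       q, k, v, /*attn_mask=*/c10::nullopt, /*dropout_p=*/0.0, /*is_causal=*/true);
//   // out: [2, 8, 16, 64] == [N, num_heads, L, E]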

std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
        const Tensor& query_, const Tensor& key, const Tensor& value,
        const c10::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
        const c10::optional<Tensor>& dropout_mask, c10::optional<double> scale) {
  C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
  if (query_.is_nested() || key.is_nested() || value.is_nested()) {
    TORCH_CHECK(
        query_.is_contiguous() && key.is_contiguous() &&
            value.is_contiguous(),
        "scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
  }
    auto attn_mask = attn_mask_;
    // Naive, composite implementation defined here.

    // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
    bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
    const auto scaling_factor = sdp::calculate_scale(query_, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();

    const auto query = query_ * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor: scaling_factor);
    if (is_causal) {
        TORCH_CHECK(!attn_mask.has_value(),
                "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
        TORCH_CHECK(!query.is_nested() && !key.is_nested(),
                "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");

        // Replace attn_mask with causal mask; lower triangular elements take part in attention.
        const auto L = query.sym_size(-2), S = key.sym_size(-2);
        attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
        attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
    }
    auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
    if (attn_mask.has_value()) {
      if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
        attn = attn.add(*attn_mask);
      } else {
        attn.add_(*attn_mask);
      }
    }
    attn = at::softmax(attn, -1);
    if (dropout_p > 0.0) {
      if (dropout_mask.has_value()) {
        // In order to validate the correctness of the fused kernels, we need to
        // use the same dropout mask in order to compare the results.
        TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
        attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
        auto dropout_scaling = 1.0 / (1 - dropout_p);
        return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
      } else {
        attn = at::dropout(attn, dropout_p, true);
      }
    }

    return std::make_tuple(at::matmul(attn, value), attn);
}
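
// Math-path summary (restating the composite computation above):
//
//   s      = scale if given, else 1 / sqrt(E)
//   attn   = softmax((Q * sqrt(s)) @ (K * sqrt(s))^T + attn_mask, dim = -1)
//   output = dropout(attn, p) @ V
//
// Multiplying Q and K each by sqrt(s), rather than scaling the product once,
// is done for numerical stability, as the comment above notes.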

Tensor triton_multi_head_attention(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const int64_t embed_dim,
    const int64_t num_head,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const c10::optional<Tensor>& mask) {
  // query shape: [B, T, D]
  // qkv_weight shape: [3 * D, D]
  TORCH_CHECK(!mask, "Only causal mask is supported for Triton.");

  const auto D = embed_dim;
  TORCH_CHECK(
      query.dim() == 3,
      "expected 3-D `query`, got ",
      query.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.sizes()[2] == embed_dim,
      "passed-in embed_dim ",
      embed_dim,
      " didn't match last dim of query ",
      query.sizes()[2]);
  TORCH_CHECK(
      key.dim() == 3,
      "expected 3-D `key`, got ",
      key.dim(),
      "-D tensor");
  TORCH_CHECK(
      value.dim() == 3,
      "expected 3-D `value`, got ",
      value.dim(),
      "-D tensor");
  TORCH_CHECK(
      query.sizes() == key.sizes() && key.sizes() == value.sizes(),
      "expected `query`/`key`/`value` shapes to match");
  TORCH_CHECK(
      qkv_weight.dim() == 2,
      "expected 2-D `qkv_weight`, got ",
      qkv_weight.dim(),
      "-D tensor");
  TORCH_CHECK(
      D * 3 == qkv_weight.sizes()[0],
      "expected `qkv_weight` first dim to be 3x embed_dim");
  TORCH_CHECK(
      D == qkv_weight.sizes()[1],
      "expected `qkv_weight` second dim to be embed_Dim");

#ifndef NDEBUG
  const auto B = query.is_nested()
      ? get_nested_tensor_impl(query)->get_nested_sizes().size(0)
      : query.sizes()[0];
  auto T = query.is_nested() ? 0 : query.sizes()[1];
  const auto dim_per_head = D / num_head;
#endif

  // shape: [B, T, 3 x D]
  auto qkv = qkv_projection(query, key, value, embed_dim, qkv_weight);

  // shape: 3 x [B, num_head, T, dim_per_head]
  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
  auto& q = std::get<0>(q_k_v);
  const auto& k = std::get<1>(q_k_v);
  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
  debug_assert_shape(__LINE__, v, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "q: " << q << std::endl;
  std::cerr << "k: " << k << std::endl;
  std::cerr << "v: " << v << std::endl;
#endif

  auto attn_ctx = at::_triton_scaled_dot_attention(q, k, v);

#ifndef NDEBUG
  debug_assert_shape(__LINE__, attn_ctx, {B, num_head, T, dim_per_head});
#endif
#ifdef DEBUG_PRINT_EACH_STEP
  std::cerr << "attn_ctx: " << attn_ctx << std::endl;
#endif

  // shape: [B, T, D]
  // Fuse transform_0213 inside
  auto proj = transform0213_gemm_nt_bias(
      attn_ctx, proj_weight, proj_bias, query);
#ifndef NDEBUG
  debug_assert_shape(__LINE__, proj, {B, T, D});
#endif
  return proj;
}
} // namespace native
} // namespace at