# AOT ID: ['5_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
from torch._C import _cuda_getCurrentRawStream as get_raw_stream



# kernel path: /tmp/torchinductor_shingokuga/x3/cx356wjlnqsglzffwnteeiqhrddtc33m4puny7vrnxr7j22ksni3.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten.expand, aten.view, aten.cat, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add => add_30
#   hidden => convert_element_type_4
#   hidden_states => mul_37
#   mul => mul_32
#   pow_1 => pow_1
#   rsqrt => rsqrt
#   special_tokens => expand
#   to => convert_element_type_3
#   variance => mean
#   x => view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg4_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %addmm : Tensor "bf16[4*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %buf1 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf1]
#   %arg7_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg7_1]
#   %expand : Tensor "bf16[1, s27, 1, 1024][1024, 0, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg4_1, [1, %arg0_1, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, s27, 4, 1024][4096*s27, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, %arg0_1, 4, 1024]), kwargs = {})
#   %cat : Tensor "bf16[1, s27, 5, 1024][5120*s27, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %view_1], 2), kwargs = {})
#   %view_2 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [%arg0_1, 5, 1024]), kwargs = {})
#   %convert_element_type_3 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_2, torch.float32), kwargs = {})
#   %pow_1 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_3, 2), kwargs = {})
#   %mean : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
#   %add_30 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
#   %rsqrt : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_30,), kwargs = {})
#   %mul_32 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_2, %rsqrt), kwargs = {})
#   %convert_element_type_4 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_32, torch.bfloat16), kwargs = {})
#   %mul_37 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_4, %arg7_1), kwargs = {})
#   return %buf1,%mul_37
triton_red_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0 = async_compile.triton('triton_red_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 24031232}}
)
@triton.jit
def triton_red_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % 5)
    x1 = xindex // 5
    _tmp14 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = x0
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 1, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), r0_mask & tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tmp0 >= tmp3
        tmp7 = tl.full([1, 1], 5, tl.int64)
        tmp8 = tmp0 < tmp7
        tmp9 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tl.where(tmp4, tmp5, tmp9)
        tmp11 = tmp10.to(tl.float32)
        tmp12 = tmp11 * tmp11
        tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
        tmp15 = _tmp14 + tmp13
        _tmp14 = tl.where(r0_mask & xmask, tmp15, _tmp14)
    tmp14 = tl.sum(_tmp14, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp35 = tl.load(in_ptr2 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp16 = x0
        tmp17 = tl.full([1, 1], 0, tl.int64)
        tmp18 = tmp16 >= tmp17
        tmp19 = tl.full([1, 1], 1, tl.int64)
        tmp20 = tmp16 < tmp19
        tmp21 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), r0_mask & tmp20 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp22 = tmp16 >= tmp19
        tmp23 = tl.full([1, 1], 5, tl.int64)
        tmp24 = tmp16 < tmp23
        tmp25 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), r0_mask & tmp22 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp26 = tl.where(tmp20, tmp21, tmp25)
        tmp27 = tmp26.to(tl.float32)
        tmp28 = tl.full([1, 1], 1024.0, tl.float32)
        tmp29 = (tmp14 / tmp28)
        tmp30 = tl.full([1, 1], 1e-05, tl.float32)
        tmp31 = tmp29 + tmp30
        tmp32 = libdevice.rsqrt(tmp31)
        tmp33 = tmp27 * tmp32
        tmp34 = tmp33.to(tl.float32)
        tmp36 = tmp34 * tmp35
        tl.store(out_ptr1 + (r0_2 + 1024*x3), tmp36, r0_mask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/e6/ce6hpumgziia2icpvjapj7kqaj5e7mtqg5cemsytau7d7tzed6xy.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_12
#   k_embed => add_187
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_14
#   key_states_3 => clone_1
#   mul_2 => mul_99
#   mul_3 => mul_110
#   mul_4 => mul_115
#   mul_5 => mul_126
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_11
#   q_embed => add_151
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_13
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm : Tensor "bf16[5*s27, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm]
#   %arg5_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %arg6_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg6_1]
#   %view_4 : Tensor "bf16[s27, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [%arg0_1, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[s27, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [%arg0_1, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[s27, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [%arg0_1, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[s27, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [%arg0_1, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_12 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_12, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_99 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg : Tensor "f32[s27, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg6_1, [%iota]), kwargs = {})
#   %mul_110 : Tensor "f32[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_151 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_99, %mul_110), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_151, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_115 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_12, %index), kwargs = {})
#   %neg_1 : Tensor "f32[s27, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_126 : Tensor "f32[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_187 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_115, %mul_126), kwargs = {})
#   %convert_element_type_14 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_187, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_14,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[s27, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [%arg0_1, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[s27, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [%arg0_1, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf6
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 85813760}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x4 = xindex
    x0 = (xindex % 128)
    x2 = ((xindex // 2048) % 5)
    x5 = xindex // 128
    x1 = ((xindex // 128) % 16)
    x3 = xindex // 10240
    tmp0 = tl.load(in_ptr0 + (x4), xmask).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = x0
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 64, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tl.load(in_ptr0 + (64 + 128*x5 + (x0)), tmp9 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp5 >= tmp8
    tmp16 = tl.full([1], 128, tl.int64)
    tmp17 = tmp5 < tmp16
    tmp18 = tl.load(in_ptr0 + (128*x5 + ((-64) + x0)), tmp15 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tl.where(tmp9, tmp14, tmp21)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp4 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*x2 + 640*x1 + 10240*x3), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/uk/cukqhhiymy2hh464ru3ix5jlmsvkwj4oug2a4z462gjwnyfb33bj.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_12
#   k_embed => add_187
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_14
#   key_states_3 => clone_1
#   mul_2 => mul_99
#   mul_3 => mul_110
#   mul_4 => mul_115
#   mul_5 => mul_126
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_11
#   q_embed => add_151
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_13
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm_1 : Tensor "bf16[5*s27, 256][256, 1]cuda:0" = PlaceHolder[target=mm_1]
#   %arg5_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %arg6_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg6_1]
#   %view_4 : Tensor "bf16[s27, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [%arg0_1, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[s27, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [%arg0_1, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[s27, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [%arg0_1, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[s27, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [%arg0_1, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_12 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_12, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_99 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg : Tensor "f32[s27, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg6_1, [%iota]), kwargs = {})
#   %mul_110 : Tensor "f32[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_151 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_99, %mul_110), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_151, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_115 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_12, %index), kwargs = {})
#   %neg_1 : Tensor "f32[s27, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_126 : Tensor "f32[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_187 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_115, %mul_126), kwargs = {})
#   %convert_element_type_14 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_187, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_14,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[s27, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [%arg0_1, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[s27, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [%arg0_1, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf7
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 10728960}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x4 = xindex
    x0 = (xindex % 128)
    x2 = ((xindex // 256) % 5)
    x5 = xindex // 128
    x1 = ((xindex // 128) % 2)
    x3 = xindex // 1280
    tmp0 = tl.load(in_ptr0 + (x4), xmask).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = x0
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 64, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tl.load(in_ptr0 + (64 + 128*x5 + (x0)), tmp9 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp5 >= tmp8
    tmp16 = tl.full([1], 128, tl.int64)
    tmp17 = tmp5 < tmp16
    tmp18 = tl.load(in_ptr0 + (128*x5 + ((-64) + x0)), tmp15 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tl.where(tmp9, tmp14, tmp21)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp4 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*x2 + 640*x1 + 1280*x3), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/l3/cl364hkejty4pmsvqnxfiooqrrphp6qc4njpvydbkkx2g6doxuh4.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_12
#   k_embed => add_187
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_14
#   key_states_3 => clone_1
#   mul_2 => mul_99
#   mul_3 => mul_110
#   mul_4 => mul_115
#   mul_5 => mul_126
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_11
#   q_embed => add_151
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_13
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm_2 : Tensor "bf16[5*s27, 256][256, 1]cuda:0" = PlaceHolder[target=mm_2]
#   %view_4 : Tensor "bf16[s27, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [%arg0_1, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[s27, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [%arg0_1, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[s27, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [%arg0_1, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[s27, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [%arg0_1, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_12 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_12, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_99 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg : Tensor "f32[s27, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg6_1, [%iota]), kwargs = {})
#   %mul_110 : Tensor "f32[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_151 : Tensor "f32[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_99, %mul_110), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[s27, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_151, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_115 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_12, %index), kwargs = {})
#   %neg_1 : Tensor "f32[s27, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_126 : Tensor "f32[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_187 : Tensor "f32[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_115, %mul_126), kwargs = {})
#   %convert_element_type_14 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_187, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_14,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[s27, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [%arg0_1, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[s27, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [%arg0_1, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[s27, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[s27, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf8
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 6435840}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 128)
    x1 = ((xindex // 128) % 5)
    x2 = ((xindex // 640) % 2)
    x3 = xindex // 1280
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + 256*x1 + 1280*x3), xmask).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/as/castk3p3frzh7mm4psypn73pgzhmmta5v2f2455qt2qib5gxtf2z.py
# Topologically Sorted Source Nodes: [transpose_3, attn_output_1], Original ATen: [aten.transpose, aten.clone]
# Source node to ATen node mapping:
#   attn_output_1 => clone_3
#   transpose_3 => permute_7
# Graph fragment:
#   %getitem_4 : Tensor "bf16[s27, 16, 5, 128][10240, 640, 128, 1]cuda:0" = PlaceHolder[target=getitem_4]
#   %permute_7 : Tensor "bf16[s27, 5, 16, 128][10240, 128, 640, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%getitem_4, [0, 2, 1, 3]), kwargs = {})
#   %clone_3 : Tensor "bf16[s27, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_7,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_3
triton_poi_fused_clone_transpose_4 = async_compile.triton('triton_poi_fused_clone_transpose_4', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_transpose_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 51486720}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_transpose_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 128)
    x1 = ((xindex // 128) % 16)
    x2 = ((xindex // 2048) % 5)
    x3 = xindex // 10240
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + 640*x1 + 10240*x3), xmask).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zx/czxpkhohrptsrrumc76vqw7ssh5lcasv4iafpf6rpjykxvzaj24g.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, to_6, pow_2, variance_1, add_4, rsqrt_1, mul_6, hidden_1, hidden_states_2], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_4 => add_268
#   attn_output_3 => view_14
#   hidden_1 => convert_element_type_18
#   hidden_states_1 => add_251
#   hidden_states_2 => mul_206
#   mul_6 => mul_201
#   pow_2 => pow_2
#   rsqrt_1 => rsqrt_1
#   special_tokens => expand
#   to_6 => convert_element_type_17
#   variance_1 => mean_1
#   x => view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg4_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %addmm : Tensor "bf16[4*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %mm_3 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %buf17 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf17]
#   %arg12_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg12_1]
#   %expand : Tensor "bf16[1, s27, 1, 1024][1024, 0, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg4_1, [1, %arg0_1, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, s27, 4, 1024][4096*s27, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, %arg0_1, 4, 1024]), kwargs = {})
#   %cat : Tensor "bf16[1, s27, 5, 1024][5120*s27, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %view_1], 2), kwargs = {})
#   %view_2 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [%arg0_1, 5, 1024]), kwargs = {})
#   %view_14 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_251 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %view_14), kwargs = {})
#   %convert_element_type_17 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_251, torch.float32), kwargs = {})
#   %pow_2 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_17, 2), kwargs = {})
#   %mean_1 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
#   %add_268 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
#   %rsqrt_1 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_268,), kwargs = {})
#   %mul_201 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %rsqrt_1), kwargs = {})
#   %convert_element_type_18 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_201, torch.bfloat16), kwargs = {})
#   %mul_206 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_18, %arg12_1), kwargs = {})
#   return %buf17,%mul_206
triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5 = async_compile.triton('triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 32612352}}
)
@triton.jit
def triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % 5)
    x1 = xindex // 5
    x3 = xindex
    _tmp16 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp11 = tl.load(in_ptr2 + (r0_2 + 1024*x3), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp0 = x0
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 1, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), r0_mask & tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tmp0 >= tmp3
        tmp7 = tl.full([1, 1], 5, tl.int64)
        tmp8 = tmp0 < tmp7
        tmp9 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tl.where(tmp4, tmp5, tmp9)
        tmp12 = tmp10 + tmp11
        tmp13 = tmp12.to(tl.float32)
        tmp14 = tmp13 * tmp13
        tmp15 = tl.broadcast_to(tmp14, [XBLOCK, R0_BLOCK])
        tmp17 = _tmp16 + tmp15
        _tmp16 = tl.where(r0_mask & xmask, tmp17, _tmp16)
    tmp16 = tl.sum(_tmp16, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp29 = tl.load(in_ptr2 + (r0_2 + 1024*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp39 = tl.load(in_ptr3 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp18 = x0
        tmp19 = tl.full([1, 1], 0, tl.int64)
        tmp20 = tmp18 >= tmp19
        tmp21 = tl.full([1, 1], 1, tl.int64)
        tmp22 = tmp18 < tmp21
        tmp23 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), r0_mask & tmp22 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp24 = tmp18 >= tmp21
        tmp25 = tl.full([1, 1], 5, tl.int64)
        tmp26 = tmp18 < tmp25
        tmp27 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), r0_mask & tmp24 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp28 = tl.where(tmp22, tmp23, tmp27)
        tmp30 = tmp28 + tmp29
        tmp31 = tmp30.to(tl.float32)
        tmp32 = tl.full([1, 1], 1024.0, tl.float32)
        tmp33 = (tmp16 / tmp32)
        tmp34 = tl.full([1, 1], 1e-05, tl.float32)
        tmp35 = tmp33 + tmp34
        tmp36 = libdevice.rsqrt(tmp35)
        tmp37 = tmp31 * tmp36
        tmp38 = tmp37.to(tl.float32)
        tmp40 = tmp38 * tmp39
        tl.store(out_ptr1 + (r0_2 + 1024*x3), tmp40, r0_mask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/x7/cx7wozzxnvoslqulhofuzhlauztzmrbx7mbypdbjhmqqdveveeqc.py
# Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
# Source node to ATen node mapping:
#   linear_5 => view_16
#   linear_6 => view_18
#   mul_8 => mul_241
#   silu => add_299, convert_element_type_21, convert_element_type_22, div, exp, neg_2
# Graph fragment:
#   %mm_4 : Tensor "bf16[5*s27, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_4]
#   %mm_5 : Tensor "bf16[5*s27, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_5]
#   %view_16 : Tensor "bf16[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_4, [%arg0_1, 5, 4096]), kwargs = {})
#   %convert_element_type_21 : Tensor "f32[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_16, torch.float32), kwargs = {})
#   %neg_2 : Tensor "f32[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%convert_element_type_21,), kwargs = {})
#   %exp : Tensor "f32[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%neg_2,), kwargs = {})
#   %add_299 : Tensor "f32[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%exp, 1), kwargs = {})
#   %div : Tensor "f32[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_21, %add_299), kwargs = {})
#   %convert_element_type_22 : Tensor "bf16[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%div, torch.bfloat16), kwargs = {})
#   %view_18 : Tensor "bf16[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_5, [%arg0_1, 5, 4096]), kwargs = {})
#   %mul_241 : Tensor "bf16[s27, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_22, %view_18), kwargs = {})
#   return %mul_241
triton_poi_fused__unsafe_view_mul_silu_6 = async_compile.triton('triton_poi_fused__unsafe_view_mul_silu_6', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__unsafe_view_mul_silu_6', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 137297920}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__unsafe_view_mul_silu_6(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = -tmp1
    tmp3 = libdevice.exp(tmp2)
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = (tmp1 / tmp5)
    tmp7 = tmp6.to(tl.float32)
    tmp9 = tmp7 * tmp8
    tl.store(in_out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/64/c64n3xvqvsowuhzzplwtkn3r7mnhfplwhhvnhkomnyoksln7wx5t.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_6 => add_345
#   attn_output_3 => view_14
#   hidden_2 => convert_element_type_28
#   hidden_states_1 => add_251
#   hidden_states_3 => view_20
#   hidden_states_4 => add_328
#   hidden_states_5 => mul_278
#   mul_9 => mul_273
#   pow_3 => pow_3
#   rsqrt_2 => rsqrt_2
#   special_tokens => expand
#   to_8 => convert_element_type_27
#   variance_2 => mean_2
#   x => view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg4_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %addmm : Tensor "bf16[4*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %mm_3 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %buf23 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf23]
#   %arg16_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg16_1]
#   %expand : Tensor "bf16[1, s27, 1, 1024][1024, 0, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg4_1, [1, %arg0_1, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, s27, 4, 1024][4096*s27, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, %arg0_1, 4, 1024]), kwargs = {})
#   %cat : Tensor "bf16[1, s27, 5, 1024][5120*s27, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %view_1], 2), kwargs = {})
#   %view_2 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [%arg0_1, 5, 1024]), kwargs = {})
#   %view_14 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_251 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %view_14), kwargs = {})
#   %view_20 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_6, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_328 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_251, %view_20), kwargs = {})
#   %convert_element_type_27 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_328, torch.float32), kwargs = {})
#   %pow_3 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_27, 2), kwargs = {})
#   %mean_2 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
#   %add_345 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {})
#   %rsqrt_2 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_345,), kwargs = {})
#   %mul_273 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_328, %rsqrt_2), kwargs = {})
#   %convert_element_type_28 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_273, torch.bfloat16), kwargs = {})
#   %mul_278 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_28, %arg16_1), kwargs = {})
#   return %buf23,%mul_278
triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7 = async_compile.triton('triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 9, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 41193472}}
)
@triton.jit
def triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = (xindex % 5)
    x1 = xindex // 5
    x3 = xindex
    _tmp18 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp11 = tl.load(in_ptr2 + (r0_2 + 1024*x3), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr3 + (r0_2 + 1024*x3), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp0 = x0
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 1, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), r0_mask & tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tmp0 >= tmp3
        tmp7 = tl.full([1, 1], 5, tl.int64)
        tmp8 = tmp0 < tmp7
        tmp9 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), r0_mask & tmp6 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tl.where(tmp4, tmp5, tmp9)
        tmp12 = tmp10 + tmp11
        tmp14 = tmp12 + tmp13
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tmp15 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(r0_mask & xmask, tmp19, _tmp18)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp31 = tl.load(in_ptr2 + (r0_2 + 1024*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp33 = tl.load(in_ptr3 + (r0_2 + 1024*x3), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp43 = tl.load(in_ptr4 + (r0_2), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp20 = x0
        tmp21 = tl.full([1, 1], 0, tl.int64)
        tmp22 = tmp20 >= tmp21
        tmp23 = tl.full([1, 1], 1, tl.int64)
        tmp24 = tmp20 < tmp23
        tmp25 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), r0_mask & tmp24 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp26 = tmp20 >= tmp23
        tmp27 = tl.full([1, 1], 5, tl.int64)
        tmp28 = tmp20 < tmp27
        tmp29 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), r0_mask & tmp26 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp30 = tl.where(tmp24, tmp25, tmp29)
        tmp32 = tmp30 + tmp31
        tmp34 = tmp32 + tmp33
        tmp35 = tmp34.to(tl.float32)
        tmp36 = tl.full([1, 1], 1024.0, tl.float32)
        tmp37 = (tmp18 / tmp36)
        tmp38 = tl.full([1, 1], 1e-05, tl.float32)
        tmp39 = tmp37 + tmp38
        tmp40 = libdevice.rsqrt(tmp39)
        tmp41 = tmp35 * tmp40
        tmp42 = tmp41.to(tl.float32)
        tmp44 = tmp42 * tmp43
        tl.store(out_ptr1 + (r0_2 + 1024*x3), tmp44, r0_mask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ok/cokb4yy6ogms7dkxacx6b4tftgffel6a2edyvidyuhbns6waiord.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, attn_output_7, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_10 => add_583
#   attn_output_3 => view_14
#   attn_output_7 => view_32
#   hidden_3 => convert_element_type_42
#   hidden_states_1 => add_251
#   hidden_states_3 => view_20
#   hidden_states_4 => add_328
#   hidden_states_6 => add_566
#   hidden_states_7 => mul_447
#   mul_15 => mul_442
#   pow_4 => pow_4
#   rsqrt_3 => rsqrt_3
#   special_tokens => expand
#   to_14 => convert_element_type_41
#   variance_3 => mean_3
#   x => view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg4_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %addmm : Tensor "bf16[4*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %mm_3 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %mm_10 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_10]
#   %add_566 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_566]
#   %buf40 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf40]
#   %arg21_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg21_1]
#   %expand : Tensor "bf16[1, s27, 1, 1024][1024, 0, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg4_1, [1, %arg0_1, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, s27, 4, 1024][4096*s27, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, %arg0_1, 4, 1024]), kwargs = {})
#   %cat : Tensor "bf16[1, s27, 5, 1024][5120*s27, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %view_1], 2), kwargs = {})
#   %view_2 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [%arg0_1, 5, 1024]), kwargs = {})
#   %view_14 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_251 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %view_14), kwargs = {})
#   %view_20 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_6, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_328 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_251, %view_20), kwargs = {})
#   %view_32 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_10, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_566 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_328, %view_32), kwargs = {})
#   %convert_element_type_41 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_566, torch.float32), kwargs = {})
#   %pow_4 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_41, 2), kwargs = {})
#   %mean_3 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {})
#   %add_583 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {})
#   %rsqrt_3 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_583,), kwargs = {})
#   %mul_442 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_566, %rsqrt_3), kwargs = {})
#   %convert_element_type_42 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_442, torch.bfloat16), kwargs = {})
#   %mul_447 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_42, %arg21_1), kwargs = {})
#   return %add_566,%buf40,%mul_447
triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 66936832}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = (xindex % 5)
    r0_2 = r0_index
    x1 = xindex // 5
    x3 = xindex
    tmp11 = tl.load(in_out_ptr0 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp13 = tl.load(in_ptr2 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp15 = tl.load(in_ptr3 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp30 = tl.load(in_ptr4 + (r0_2), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 5, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), tmp6 & xmask, other=0.0).to(tl.float32)
    tmp10 = tl.where(tmp4, tmp5, tmp9)
    tmp12 = tmp10 + tmp11
    tmp14 = tmp12 + tmp13
    tmp16 = tmp14 + tmp15
    tmp17 = tmp16.to(tl.float32)
    tmp18 = tmp17 * tmp17
    tmp19 = tl.broadcast_to(tmp18, [XBLOCK, R0_BLOCK])
    tmp21 = tl.where(xmask, tmp19, 0)
    tmp22 = tl.sum(tmp21, 1)[:, None].to(tl.float32)
    tmp23 = tl.full([1, 1], 1024.0, tl.float32)
    tmp24 = (tmp22 / tmp23)
    tmp25 = tl.full([1, 1], 1e-05, tl.float32)
    tmp26 = tmp24 + tmp25
    tmp27 = libdevice.rsqrt(tmp26)
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28.to(tl.float32)
    tmp31 = tmp29 * tmp30
    tl.store(in_out_ptr0 + (r0_2 + 1024*x3), tmp16, xmask)
    tl.store(out_ptr1 + (r0_2 + 1024*x3), tmp31, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/7w/c7wxjknsknw3dpdvlmolkss6n4y563eojivqm5fvdqyufon25und.py
# Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_12 => add_660
#   hidden_4 => convert_element_type_52
#   hidden_states_10 => mul_519
#   hidden_states_8 => view_38
#   hidden_states_9 => add_643
#   mul_18 => mul_514
#   pow_5 => pow_5
#   rsqrt_4 => rsqrt_4
#   to_16 => convert_element_type_51
#   variance_4 => mean_4
# Graph fragment:
#   %add_566 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_566]
#   %mm_13 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %buf46 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf46]
#   %arg25_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg25_1]
#   %view_38 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_643 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_566, %view_38), kwargs = {})
#   %convert_element_type_51 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_643, torch.float32), kwargs = {})
#   %pow_5 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_51, 2), kwargs = {})
#   %mean_4 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {})
#   %add_660 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {})
#   %rsqrt_4 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_660,), kwargs = {})
#   %mul_514 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_643, %rsqrt_4), kwargs = {})
#   %convert_element_type_52 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_514, torch.bfloat16), kwargs = {})
#   %mul_519 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_52, %arg25_1), kwargs = {})
#   return %buf46,%mul_519
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 34326528}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp3 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
    tmp7 = tl.where(xmask, tmp5, 0)
    tmp8 = tl.sum(tmp7, 1)[:, None].to(tl.float32)
    tmp9 = tl.full([1, 1], 1024.0, tl.float32)
    tmp10 = (tmp8 / tmp9)
    tmp11 = tl.full([1, 1], 1e-05, tl.float32)
    tmp12 = tmp10 + tmp11
    tmp13 = libdevice.rsqrt(tmp12)
    tmp14 = tmp3 * tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp17 = tmp15 * tmp16
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp17, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qk/cqkbpgfajnhvo6khusrm2saj3wleafmwe7luhuvwpzqfdzaoeaip.py
# Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, to_22, pow_6, variance_5, add_16, rsqrt_5, mul_24, hidden_5, hidden_states_12], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_16 => add_898
#   attn_output_11 => view_50
#   hidden_5 => convert_element_type_66
#   hidden_states_11 => add_881
#   hidden_states_12 => mul_688
#   hidden_states_8 => view_38
#   hidden_states_9 => add_643
#   mul_24 => mul_683
#   pow_6 => pow_6
#   rsqrt_5 => rsqrt_5
#   to_22 => convert_element_type_65
#   variance_5 => mean_5
# Graph fragment:
#   %add_566 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_566]
#   %mm_13 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %mm_17 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_17]
#   %buf62 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf62]
#   %arg30_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg30_1]
#   %view_38 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_643 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_566, %view_38), kwargs = {})
#   %view_50 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_17, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_881 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_643, %view_50), kwargs = {})
#   %convert_element_type_65 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_881, torch.float32), kwargs = {})
#   %pow_6 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_65, 2), kwargs = {})
#   %mean_5 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_6, [-1], True), kwargs = {})
#   %add_898 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_5, 1e-05), kwargs = {})
#   %rsqrt_5 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_898,), kwargs = {})
#   %mul_683 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_881, %rsqrt_5), kwargs = {})
#   %convert_element_type_66 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_683, torch.bfloat16), kwargs = {})
#   %mul_688 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_66, %arg30_1), kwargs = {})
#   return %buf62,%mul_688
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 42907648}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp18 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp5 * tmp5
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
    tmp9 = tl.where(xmask, tmp7, 0)
    tmp10 = tl.sum(tmp9, 1)[:, None].to(tl.float32)
    tmp11 = tl.full([1, 1], 1024.0, tl.float32)
    tmp12 = (tmp10 / tmp11)
    tmp13 = tl.full([1, 1], 1e-05, tl.float32)
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp5 * tmp15
    tmp17 = tmp16.to(tl.float32)
    tmp19 = tmp17 * tmp18
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp19, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/3u/c3uazeq6a6bafcn2omitccrdw5m2ttwl3cs3u53lhtj65vokyxma.py
# Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, to_24, pow_7, variance_6, add_18, rsqrt_6, mul_27, hidden_6, hidden_states_15], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_18 => add_975
#   attn_output_11 => view_50
#   hidden_6 => convert_element_type_76
#   hidden_states_11 => add_881
#   hidden_states_13 => view_56
#   hidden_states_14 => add_958
#   hidden_states_15 => mul_760
#   hidden_states_8 => view_38
#   hidden_states_9 => add_643
#   mul_27 => mul_755
#   pow_7 => pow_7
#   rsqrt_6 => rsqrt_6
#   to_24 => convert_element_type_75
#   variance_6 => mean_6
# Graph fragment:
#   %add_566 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_566]
#   %mm_13 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %mm_17 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_17]
#   %mm_20 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_20]
#   %buf68 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf68]
#   %arg34_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg34_1]
#   %view_38 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_643 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_566, %view_38), kwargs = {})
#   %view_50 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_17, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_881 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_643, %view_50), kwargs = {})
#   %view_56 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_20, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_958 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_881, %view_56), kwargs = {})
#   %convert_element_type_75 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_958, torch.float32), kwargs = {})
#   %pow_7 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_75, 2), kwargs = {})
#   %mean_6 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_7, [-1], True), kwargs = {})
#   %add_975 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_6, 1e-05), kwargs = {})
#   %rsqrt_6 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_975,), kwargs = {})
#   %mul_755 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_958, %rsqrt_6), kwargs = {})
#   %convert_element_type_76 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_755, torch.bfloat16), kwargs = {})
#   %mul_760 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_76, %arg34_1), kwargs = {})
#   return %buf68,%mul_760
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 51488768}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp20 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tmp7 * tmp7
    tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
    tmp11 = tl.where(xmask, tmp9, 0)
    tmp12 = tl.sum(tmp11, 1)[:, None].to(tl.float32)
    tmp13 = tl.full([1, 1], 1024.0, tl.float32)
    tmp14 = (tmp12 / tmp13)
    tmp15 = tl.full([1, 1], 1e-05, tl.float32)
    tmp16 = tmp14 + tmp15
    tmp17 = libdevice.rsqrt(tmp16)
    tmp18 = tmp7 * tmp17
    tmp19 = tmp18.to(tl.float32)
    tmp21 = tmp19 * tmp20
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp21, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qg/cqgtnfkxkmxf6anskcexgchucfih6c2sksriywewgnmczsymdx2p.py
# Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, attn_output_15, hidden_states_16, to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_22 => add_1213
#   attn_output_11 => view_50
#   attn_output_15 => view_68
#   hidden_7 => convert_element_type_90
#   hidden_states_11 => add_881
#   hidden_states_13 => view_56
#   hidden_states_14 => add_958
#   hidden_states_16 => add_1196
#   hidden_states_17 => mul_929
#   hidden_states_8 => view_38
#   hidden_states_9 => add_643
#   mul_33 => mul_924
#   pow_8 => pow_8
#   rsqrt_7 => rsqrt_7
#   to_30 => convert_element_type_89
#   variance_7 => mean_7
# Graph fragment:
#   %add_566 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_566]
#   %mm_13 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %mm_17 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_17]
#   %mm_20 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_20]
#   %mm_24 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_24]
#   %add_1196 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_1196]
#   %buf85 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf85]
#   %arg39_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg39_1]
#   %view_38 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_643 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_566, %view_38), kwargs = {})
#   %view_50 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_17, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_881 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_643, %view_50), kwargs = {})
#   %view_56 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_20, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_958 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_881, %view_56), kwargs = {})
#   %view_68 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_24, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_1196 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_958, %view_68), kwargs = {})
#   %convert_element_type_89 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1196, torch.float32), kwargs = {})
#   %pow_8 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_89, 2), kwargs = {})
#   %mean_7 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_8, [-1], True), kwargs = {})
#   %add_1213 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_7, 1e-05), kwargs = {})
#   %rsqrt_7 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_1213,), kwargs = {})
#   %mul_924 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_1196, %rsqrt_7), kwargs = {})
#   %convert_element_type_90 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_924, torch.bfloat16), kwargs = {})
#   %mul_929 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_90, %arg39_1), kwargs = {})
#   return %add_1196,%buf85,%mul_929
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 77232128}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp22 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp8 = tmp6 + tmp7
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp9 * tmp9
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp13 = tl.where(xmask, tmp11, 0)
    tmp14 = tl.sum(tmp13, 1)[:, None].to(tl.float32)
    tmp15 = tl.full([1, 1], 1024.0, tl.float32)
    tmp16 = (tmp14 / tmp15)
    tmp17 = tl.full([1, 1], 1e-05, tl.float32)
    tmp18 = tmp16 + tmp17
    tmp19 = libdevice.rsqrt(tmp18)
    tmp20 = tmp9 * tmp19
    tmp21 = tmp20.to(tl.float32)
    tmp23 = tmp21 * tmp22
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp8, xmask)
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp23, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zu/czuaeqkzxt6khskjnidbw3lw53oujuqez7za6us4lmfqnogfjfwn.py
# Topologically Sorted Source Nodes: [hidden_states_58, hidden_states_59, to_96, pow_25, variance_24, add_72, rsqrt_24, mul_108, hidden_24, hidden_states_60], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_72 => add_3810
#   hidden_24 => convert_element_type_292
#   hidden_states_58 => view_218
#   hidden_states_59 => add_3793
#   hidden_states_60 => mul_2929
#   mul_108 => mul_2924
#   pow_25 => pow_25
#   rsqrt_24 => rsqrt_24
#   to_96 => convert_element_type_291
#   variance_24 => mean_24
# Graph fragment:
#   %add_3716 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_3716]
#   %mm_83 : Tensor "bf16[5*s27, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_83]
#   %buf271 : Tensor "f32[s27, 5, 1][5, 1, 5*s27]cuda:0" = PlaceHolder[target=buf271]
#   %arg115_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg115_1]
#   %view_218 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_83, [%arg0_1, 5, 1024]), kwargs = {})
#   %add_3793 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3716, %view_218), kwargs = {})
#   %convert_element_type_291 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3793, torch.float32), kwargs = {})
#   %pow_25 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_291, 2), kwargs = {})
#   %mean_24 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_25, [-1], True), kwargs = {})
#   %add_3810 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_24, 1e-05), kwargs = {})
#   %rsqrt_24 : Tensor "f32[s27, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_3810,), kwargs = {})
#   %mul_2924 : Tensor "f32[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_3793, %rsqrt_24), kwargs = {})
#   %convert_element_type_292 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_2924, torch.bfloat16), kwargs = {})
#   %mul_2929 : Tensor "bf16[s27, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_292, %arg115_1), kwargs = {})
#   return %buf271,%mul_2929
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8192, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 34326528}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13(in_out_ptr0, in_ptr0, in_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp16 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp3 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
    tmp7 = tl.where(xmask, tmp5, 0)
    tmp8 = tl.sum(tmp7, 1)[:, None].to(tl.float32)
    tmp9 = tl.full([1, 1], 1024.0, tl.float32)
    tmp10 = (tmp8 / tmp9)
    tmp11 = tl.full([1, 1], 1e-05, tl.float32)
    tmp12 = tmp10 + tmp11
    tmp13 = libdevice.rsqrt(tmp12)
    tmp14 = tmp3 * tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp17 = tmp15 * tmp16
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp17, xmask)
''', device_str='cuda')

def partition_0(args):
    arg3_1, arg1_1, arg2_1, arg4_1, arg7_1, arg8_1, arg9_1, arg10_1, arg5_1, arg6_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, s27 = args
    args.clear()
    s27 = s27
    assert_size_stride(arg3_1, (1024, ), (1, ))
    assert_size_stride(arg1_1, (1, s27, 4, 64), (256*s27, 256, 64, 1))
    assert_size_stride(arg2_1, (1024, 64), (64, 1))
    assert_size_stride(arg4_1, (1, 1, 1, 1024), (1024, 1024, 1024, 1))
    assert_size_stride(arg7_1, (1024, ), (1, ))
    assert_size_stride(arg8_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg9_1, (256, 1024), (1024, 1))
    assert_size_stride(arg10_1, (256, 1024), (1024, 1))
    assert_size_stride(arg5_1, (32768, 128), (128, 1))
    assert_size_stride(arg6_1, (32768, 128), (128, 1))
    assert_size_stride(arg11_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg12_1, (1024, ), (1, ))
    assert_size_stride(arg13_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg14_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg15_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg16_1, (1024, ), (1, ))
    assert_size_stride(arg17_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg18_1, (256, 1024), (1024, 1))
    assert_size_stride(arg19_1, (256, 1024), (1024, 1))
    assert_size_stride(arg20_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg21_1, (1024, ), (1, ))
    assert_size_stride(arg22_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg23_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg24_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg25_1, (1024, ), (1, ))
    assert_size_stride(arg26_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg27_1, (256, 1024), (1024, 1))
    assert_size_stride(arg28_1, (256, 1024), (1024, 1))
    assert_size_stride(arg29_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg30_1, (1024, ), (1, ))
    assert_size_stride(arg31_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg32_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg33_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg34_1, (1024, ), (1, ))
    assert_size_stride(arg35_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg36_1, (256, 1024), (1024, 1))
    assert_size_stride(arg37_1, (256, 1024), (1024, 1))
    assert_size_stride(arg38_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg39_1, (1024, ), (1, ))
    assert_size_stride(arg40_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg41_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg42_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg43_1, (1024, ), (1, ))
    assert_size_stride(arg44_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg45_1, (256, 1024), (1024, 1))
    assert_size_stride(arg46_1, (256, 1024), (1024, 1))
    assert_size_stride(arg47_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg48_1, (1024, ), (1, ))
    assert_size_stride(arg49_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg50_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg51_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg52_1, (1024, ), (1, ))
    assert_size_stride(arg53_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg54_1, (256, 1024), (1024, 1))
    assert_size_stride(arg55_1, (256, 1024), (1024, 1))
    assert_size_stride(arg56_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg57_1, (1024, ), (1, ))
    assert_size_stride(arg58_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg59_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg60_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg61_1, (1024, ), (1, ))
    assert_size_stride(arg62_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg63_1, (256, 1024), (1024, 1))
    assert_size_stride(arg64_1, (256, 1024), (1024, 1))
    assert_size_stride(arg65_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg66_1, (1024, ), (1, ))
    assert_size_stride(arg67_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg68_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg69_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg70_1, (1024, ), (1, ))
    assert_size_stride(arg71_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg72_1, (256, 1024), (1024, 1))
    assert_size_stride(arg73_1, (256, 1024), (1024, 1))
    assert_size_stride(arg74_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg75_1, (1024, ), (1, ))
    assert_size_stride(arg76_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg77_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg78_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg79_1, (1024, ), (1, ))
    assert_size_stride(arg80_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg81_1, (256, 1024), (1024, 1))
    assert_size_stride(arg82_1, (256, 1024), (1024, 1))
    assert_size_stride(arg83_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg84_1, (1024, ), (1, ))
    assert_size_stride(arg85_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg86_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg87_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg88_1, (1024, ), (1, ))
    assert_size_stride(arg89_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg90_1, (256, 1024), (1024, 1))
    assert_size_stride(arg91_1, (256, 1024), (1024, 1))
    assert_size_stride(arg92_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg93_1, (1024, ), (1, ))
    assert_size_stride(arg94_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg95_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg96_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg97_1, (1024, ), (1, ))
    assert_size_stride(arg98_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg99_1, (256, 1024), (1024, 1))
    assert_size_stride(arg100_1, (256, 1024), (1024, 1))
    assert_size_stride(arg101_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg102_1, (1024, ), (1, ))
    assert_size_stride(arg103_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg104_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg105_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg106_1, (1024, ), (1, ))
    assert_size_stride(arg107_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg108_1, (256, 1024), (1024, 1))
    assert_size_stride(arg109_1, (256, 1024), (1024, 1))
    assert_size_stride(arg110_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg111_1, (1024, ), (1, ))
    assert_size_stride(arg112_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg113_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg114_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg115_1, (1024, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((4*s27, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [x], Original ATen: [aten.view, aten.t, aten.addmm]
        extern_kernels.addmm(arg3_1, reinterpret_tensor(arg1_1, (4*s27, 64), (64, 1), 0), reinterpret_tensor(arg2_1, (64, 1024), (1, 64), 0), alpha=1, beta=1, out=buf0)
        del arg1_1
        del arg2_1
        del arg3_1
        buf2 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten.expand, aten.view, aten.cat, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
        triton_red_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0.run(arg4_1, buf0, arg7_1, buf2, triton_red_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0_xnumel, 1024, stream=stream0)
        del arg7_1
        buf3 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states, query_states], Original ATen: [aten.expand, aten.view, aten.cat, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg8_1, (1024, 2048), (1, 1024), 0), out=buf3)
        del arg8_1
        buf4 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg9_1, (1024, 256), (1, 1024), 0), out=buf4)
        del arg9_1
        buf5 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg10_1, (1024, 256), (1, 1024), 0), out=buf5)
        del arg10_1
        buf6 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf3, arg5_1, arg6_1, buf6, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf3
        buf7 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf4, arg5_1, arg6_1, buf7, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf8 = reinterpret_tensor(buf4, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf4  # reuse
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf5, buf8, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf5
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf9 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf6, buf7, buf8, scale=0.08838834764831843)
        del buf7
        del buf8
        buf10 = buf9[0]
        assert_size_stride(buf10, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf10, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf9
        buf15 = reinterpret_tensor(buf6, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf6  # reuse
        # Topologically Sorted Source Nodes: [transpose_3, attn_output_1], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf10, buf15, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf10
        buf16 = reinterpret_tensor(buf2, (5*s27, 1024), (1024, 1), 0); del buf2  # reuse
        # Topologically Sorted Source Nodes: [transpose_3, attn_output_1, attn_output_2, attn_output_3], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf15, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg11_1, (2048, 1024), (1, 2048), 0), out=buf16)
        del arg11_1
        del buf15
        buf18 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, to_6, pow_2, variance_1, add_4, rsqrt_1, mul_6, hidden_1, hidden_states_2], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5.run(arg4_1, buf0, buf16, arg12_1, buf18, triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5_xnumel, 1024, stream=stream0)
        del arg12_1
        buf19 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_5], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf18, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg13_1, (1024, 4096), (1, 1024), 0), out=buf19)
        del arg13_1
        buf20 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_6], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf18, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg14_1, (1024, 4096), (1, 1024), 0), out=buf20)
        del arg14_1
        buf21 = reinterpret_tensor(buf19, (s27, 5, 4096), (20480, 4096, 1), 0); del buf19  # reuse
        # Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf21, buf20, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf20
        buf22 = reinterpret_tensor(buf18, (5*s27, 1024), (1024, 1), 0); del buf18  # reuse
        # Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8, hidden_states_3], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf21, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg15_1, (4096, 1024), (1, 4096), 0), out=buf22)
        del arg15_1
        buf24 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7.run(arg4_1, buf0, buf16, buf22, arg16_1, buf24, triton_red_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7_xnumel, 1024, stream=stream0)
        del arg16_1
        buf25 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_4], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf24, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg17_1, (1024, 2048), (1, 1024), 0), out=buf25)
        del arg17_1
        buf26 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_4], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf24, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg18_1, (1024, 256), (1, 1024), 0), out=buf26)
        del arg18_1
        buf27 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_3], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf24, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg19_1, (1024, 256), (1, 1024), 0), out=buf27)
        del arg19_1
        buf28 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf25, arg5_1, arg6_1, buf28, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf25
        buf29 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf26, arg5_1, arg6_1, buf29, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf30 = reinterpret_tensor(buf26, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf26  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf27, buf30, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf27
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf31 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf28, buf29, buf30, scale=0.08838834764831843)
        del buf29
        del buf30
        buf32 = buf31[0]
        assert_size_stride(buf32, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf32, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf31
        buf37 = reinterpret_tensor(buf28, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf28  # reuse
        # Topologically Sorted Source Nodes: [transpose_7, attn_output_5], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf32, buf37, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf32
        buf38 = reinterpret_tensor(buf24, (5*s27, 1024), (1024, 1), 0); del buf24  # reuse
        # Topologically Sorted Source Nodes: [transpose_7, attn_output_5, attn_output_6, attn_output_7], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf37, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg20_1, (2048, 1024), (1, 2048), 0), out=buf38)
        del arg20_1
        del buf37
        buf39 = reinterpret_tensor(buf16, (s27, 5, 1024), (5120, 1024, 1), 0); del buf16  # reuse
        buf41 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, attn_output_7, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8.run(buf39, arg4_1, buf0, buf22, buf38, arg21_1, buf41, triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8_xnumel, 1024, stream=stream0)
        del arg21_1
        del arg4_1
        del buf0
        del buf22
        buf42 = reinterpret_tensor(buf21, (5*s27, 4096), (4096, 1), 0); del buf21  # reuse
        # Topologically Sorted Source Nodes: [to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7, linear_12], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf41, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg22_1, (1024, 4096), (1, 1024), 0), out=buf42)
        del arg22_1
        buf43 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_13], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf41, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg23_1, (1024, 4096), (1, 1024), 0), out=buf43)
        del arg23_1
        buf44 = reinterpret_tensor(buf42, (s27, 5, 4096), (20480, 4096, 1), 0); del buf42  # reuse
        # Topologically Sorted Source Nodes: [linear_12, silu_1, linear_13, mul_17], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf44, buf43, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf43
        buf45 = reinterpret_tensor(buf41, (5*s27, 1024), (1024, 1), 0); del buf41  # reuse
        # Topologically Sorted Source Nodes: [linear_12, silu_1, linear_13, mul_17, hidden_states_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf44, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg24_1, (4096, 1024), (1, 4096), 0), out=buf45)
        del arg24_1
        buf47 = reinterpret_tensor(buf38, (s27, 5, 1024), (5120, 1024, 1), 0); del buf38  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf39, buf45, arg25_1, buf47, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel, 1024, stream=stream0)
        del arg25_1
        buf48 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10, query_states_8], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg26_1, (1024, 2048), (1, 1024), 0), out=buf48)
        del arg26_1
        buf49 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_8], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg27_1, (1024, 256), (1, 1024), 0), out=buf49)
        del arg27_1
        buf50 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_6], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg28_1, (1024, 256), (1, 1024), 0), out=buf50)
        del arg28_1
        buf51 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf48, arg5_1, arg6_1, buf51, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf48
        buf52 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf49, arg5_1, arg6_1, buf52, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf53 = reinterpret_tensor(buf49, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf49  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf50, buf53, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf50
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf54 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf51, buf52, buf53, scale=0.08838834764831843)
        del buf52
        del buf53
        buf55 = buf54[0]
        assert_size_stride(buf55, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf55, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf54
        buf60 = reinterpret_tensor(buf51, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf51  # reuse
        # Topologically Sorted Source Nodes: [transpose_11, attn_output_9], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf55, buf60, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf55
        buf61 = reinterpret_tensor(buf47, (5*s27, 1024), (1024, 1), 0); del buf47  # reuse
        # Topologically Sorted Source Nodes: [transpose_11, attn_output_9, attn_output_10, attn_output_11], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf60, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg29_1, (2048, 1024), (1, 2048), 0), out=buf61)
        del arg29_1
        del buf60
        buf63 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, to_22, pow_6, variance_5, add_16, rsqrt_5, mul_24, hidden_5, hidden_states_12], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf39, buf45, buf61, arg30_1, buf63, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel, 1024, stream=stream0)
        del arg30_1
        buf64 = reinterpret_tensor(buf44, (5*s27, 4096), (4096, 1), 0); del buf44  # reuse
        # Topologically Sorted Source Nodes: [linear_19], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf63, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg31_1, (1024, 4096), (1, 1024), 0), out=buf64)
        del arg31_1
        buf65 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf63, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg32_1, (1024, 4096), (1, 1024), 0), out=buf65)
        del arg32_1
        del buf63
        buf66 = reinterpret_tensor(buf64, (s27, 5, 4096), (20480, 4096, 1), 0); del buf64  # reuse
        # Topologically Sorted Source Nodes: [linear_19, silu_2, linear_20, mul_26], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf66, buf65, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf65
        buf67 = empty_strided_cuda((5*s27, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_19, silu_2, linear_20, mul_26, hidden_states_13], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf66, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg33_1, (4096, 1024), (1, 4096), 0), out=buf67)
        del arg33_1
        buf69 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, to_24, pow_7, variance_6, add_18, rsqrt_6, mul_27, hidden_6, hidden_states_15], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf39, buf45, buf61, buf67, arg34_1, buf69, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel, 1024, stream=stream0)
        del arg34_1
        buf70 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf69, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg35_1, (1024, 2048), (1, 1024), 0), out=buf70)
        del arg35_1
        buf71 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf69, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg36_1, (1024, 256), (1, 1024), 0), out=buf71)
        del arg36_1
        buf72 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_9], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf69, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg37_1, (1024, 256), (1, 1024), 0), out=buf72)
        del arg37_1
        buf73 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf70, arg5_1, arg6_1, buf73, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf70
        buf74 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf71, arg5_1, arg6_1, buf74, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf75 = reinterpret_tensor(buf71, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf71  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf72, buf75, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf72
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf76 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf73, buf74, buf75, scale=0.08838834764831843)
        del buf74
        del buf75
        buf77 = buf76[0]
        assert_size_stride(buf77, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf77, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf76
        buf82 = reinterpret_tensor(buf73, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf73  # reuse
        # Topologically Sorted Source Nodes: [transpose_15, attn_output_13], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf77, buf82, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf77
        buf83 = reinterpret_tensor(buf69, (5*s27, 1024), (1024, 1), 0); del buf69  # reuse
        # Topologically Sorted Source Nodes: [transpose_15, attn_output_13, attn_output_14, attn_output_15], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf82, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg38_1, (2048, 1024), (1, 2048), 0), out=buf83)
        del arg38_1
        del buf82
        buf84 = buf39; del buf39  # reuse
        buf86 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, attn_output_15, hidden_states_16, to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf84, buf45, buf61, buf67, buf83, arg39_1, buf86, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel, 1024, stream=stream0)
        del arg39_1
        del buf45
        del buf61
        del buf67
        buf87 = reinterpret_tensor(buf66, (5*s27, 4096), (4096, 1), 0); del buf66  # reuse
        # Topologically Sorted Source Nodes: [to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17, linear_26], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf86, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg40_1, (1024, 4096), (1, 1024), 0), out=buf87)
        del arg40_1
        buf88 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_27], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf86, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg41_1, (1024, 4096), (1, 1024), 0), out=buf88)
        del arg41_1
        buf89 = reinterpret_tensor(buf87, (s27, 5, 4096), (20480, 4096, 1), 0); del buf87  # reuse
        # Topologically Sorted Source Nodes: [linear_26, silu_3, linear_27, mul_35], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf89, buf88, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf88
        buf90 = reinterpret_tensor(buf86, (5*s27, 1024), (1024, 1), 0); del buf86  # reuse
        # Topologically Sorted Source Nodes: [linear_26, silu_3, linear_27, mul_35, hidden_states_18], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf89, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg42_1, (4096, 1024), (1, 4096), 0), out=buf90)
        del arg42_1
        buf92 = reinterpret_tensor(buf83, (s27, 5, 1024), (5120, 1024, 1), 0); del buf83  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, to_32, pow_9, variance_8, add_24, rsqrt_8, mul_36, hidden_8, hidden_states_20], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf84, buf90, arg43_1, buf92, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel, 1024, stream=stream0)
        del arg43_1
        buf93 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, to_32, pow_9, variance_8, add_24, rsqrt_8, mul_36, hidden_8, hidden_states_20, query_states_16], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg44_1, (1024, 2048), (1, 1024), 0), out=buf93)
        del arg44_1
        buf94 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_16], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg45_1, (1024, 256), (1, 1024), 0), out=buf94)
        del arg45_1
        buf95 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg46_1, (1024, 256), (1, 1024), 0), out=buf95)
        del arg46_1
        buf96 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf93, arg5_1, arg6_1, buf96, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf93
        buf97 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf94, arg5_1, arg6_1, buf97, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf98 = reinterpret_tensor(buf94, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf94  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf95, buf98, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf95
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf99 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf96, buf97, buf98, scale=0.08838834764831843)
        del buf97
        del buf98
        buf100 = buf99[0]
        assert_size_stride(buf100, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf100, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf99
        buf105 = reinterpret_tensor(buf96, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf96  # reuse
        # Topologically Sorted Source Nodes: [transpose_19, attn_output_17], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf100, buf105, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf100
        buf106 = reinterpret_tensor(buf92, (5*s27, 1024), (1024, 1), 0); del buf92  # reuse
        # Topologically Sorted Source Nodes: [transpose_19, attn_output_17, attn_output_18, attn_output_19], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf105, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg47_1, (2048, 1024), (1, 2048), 0), out=buf106)
        del arg47_1
        del buf105
        buf108 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, attn_output_19, hidden_states_21, to_38, pow_10, variance_9, add_28, rsqrt_9, mul_42, hidden_9, hidden_states_22], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf84, buf90, buf106, arg48_1, buf108, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel, 1024, stream=stream0)
        del arg48_1
        buf109 = reinterpret_tensor(buf89, (5*s27, 4096), (4096, 1), 0); del buf89  # reuse
        # Topologically Sorted Source Nodes: [linear_33], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf108, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg49_1, (1024, 4096), (1, 1024), 0), out=buf109)
        del arg49_1
        buf110 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_34], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf108, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg50_1, (1024, 4096), (1, 1024), 0), out=buf110)
        del arg50_1
        del buf108
        buf111 = reinterpret_tensor(buf109, (s27, 5, 4096), (20480, 4096, 1), 0); del buf109  # reuse
        # Topologically Sorted Source Nodes: [linear_33, silu_4, linear_34, mul_44], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf111, buf110, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf110
        buf112 = empty_strided_cuda((5*s27, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_33, silu_4, linear_34, mul_44, hidden_states_23], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf111, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg51_1, (4096, 1024), (1, 4096), 0), out=buf112)
        del arg51_1
        buf114 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, attn_output_19, hidden_states_21, hidden_states_23, hidden_states_24, to_40, pow_11, variance_10, add_30, rsqrt_10, mul_45, hidden_10, hidden_states_25], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf84, buf90, buf106, buf112, arg52_1, buf114, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel, 1024, stream=stream0)
        del arg52_1
        buf115 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf114, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg53_1, (1024, 2048), (1, 1024), 0), out=buf115)
        del arg53_1
        buf116 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf114, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg54_1, (1024, 256), (1, 1024), 0), out=buf116)
        del arg54_1
        buf117 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_15], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf114, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg55_1, (1024, 256), (1, 1024), 0), out=buf117)
        del arg55_1
        buf118 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf115, arg5_1, arg6_1, buf118, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf115
        buf119 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf116, arg5_1, arg6_1, buf119, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf120 = reinterpret_tensor(buf116, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf116  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf117, buf120, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf117
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf121 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf118, buf119, buf120, scale=0.08838834764831843)
        del buf119
        del buf120
        buf122 = buf121[0]
        assert_size_stride(buf122, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf122, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf121
        buf127 = reinterpret_tensor(buf118, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf118  # reuse
        # Topologically Sorted Source Nodes: [transpose_23, attn_output_21], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf122, buf127, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf122
        buf128 = reinterpret_tensor(buf114, (5*s27, 1024), (1024, 1), 0); del buf114  # reuse
        # Topologically Sorted Source Nodes: [transpose_23, attn_output_21, attn_output_22, attn_output_23], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf127, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg56_1, (2048, 1024), (1, 2048), 0), out=buf128)
        del arg56_1
        del buf127
        buf129 = buf84; del buf84  # reuse
        buf131 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, attn_output_19, hidden_states_21, hidden_states_23, hidden_states_24, attn_output_23, hidden_states_26, to_46, pow_12, variance_11, add_34, rsqrt_11, mul_51, hidden_11, hidden_states_27], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf129, buf90, buf106, buf112, buf128, arg57_1, buf131, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel, 1024, stream=stream0)
        del arg57_1
        del buf106
        del buf112
        del buf128
        buf132 = reinterpret_tensor(buf111, (5*s27, 4096), (4096, 1), 0); del buf111  # reuse
        # Topologically Sorted Source Nodes: [to_46, pow_12, variance_11, add_34, rsqrt_11, mul_51, hidden_11, hidden_states_27, linear_40], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf131, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg58_1, (1024, 4096), (1, 1024), 0), out=buf132)
        del arg58_1
        buf133 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_41], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf131, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg59_1, (1024, 4096), (1, 1024), 0), out=buf133)
        del arg59_1
        buf134 = reinterpret_tensor(buf132, (s27, 5, 4096), (20480, 4096, 1), 0); del buf132  # reuse
        # Topologically Sorted Source Nodes: [linear_40, silu_5, linear_41, mul_53], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf134, buf133, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf133
        buf135 = reinterpret_tensor(buf131, (5*s27, 1024), (1024, 1), 0); del buf131  # reuse
        # Topologically Sorted Source Nodes: [linear_40, silu_5, linear_41, mul_53, hidden_states_28], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf134, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg60_1, (4096, 1024), (1, 4096), 0), out=buf135)
        del arg60_1
        buf137 = reinterpret_tensor(buf90, (s27, 5, 1024), (5120, 1024, 1), 0); del buf90  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, to_48, pow_13, variance_12, add_36, rsqrt_12, mul_54, hidden_12, hidden_states_30], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf129, buf135, arg61_1, buf137, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel, 1024, stream=stream0)
        del arg61_1
        buf138 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, to_48, pow_13, variance_12, add_36, rsqrt_12, mul_54, hidden_12, hidden_states_30, query_states_24], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg62_1, (1024, 2048), (1, 1024), 0), out=buf138)
        del arg62_1
        buf139 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg63_1, (1024, 256), (1, 1024), 0), out=buf139)
        del arg63_1
        buf140 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_18], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg64_1, (1024, 256), (1, 1024), 0), out=buf140)
        del arg64_1
        buf141 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf138, arg5_1, arg6_1, buf141, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf138
        buf142 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf139, arg5_1, arg6_1, buf142, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf143 = reinterpret_tensor(buf139, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf139  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf140, buf143, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf140
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf144 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf141, buf142, buf143, scale=0.08838834764831843)
        del buf142
        del buf143
        buf145 = buf144[0]
        assert_size_stride(buf145, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf145, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf144
        buf150 = reinterpret_tensor(buf141, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf141  # reuse
        # Topologically Sorted Source Nodes: [transpose_27, attn_output_25], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf145, buf150, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf145
        buf151 = reinterpret_tensor(buf137, (5*s27, 1024), (1024, 1), 0); del buf137  # reuse
        # Topologically Sorted Source Nodes: [transpose_27, attn_output_25, attn_output_26, attn_output_27], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf150, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg65_1, (2048, 1024), (1, 2048), 0), out=buf151)
        del arg65_1
        del buf150
        buf153 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, attn_output_27, hidden_states_31, to_54, pow_14, variance_13, add_40, rsqrt_13, mul_60, hidden_13, hidden_states_32], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf129, buf135, buf151, arg66_1, buf153, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel, 1024, stream=stream0)
        del arg66_1
        buf154 = reinterpret_tensor(buf134, (5*s27, 4096), (4096, 1), 0); del buf134  # reuse
        # Topologically Sorted Source Nodes: [linear_47], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf153, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg67_1, (1024, 4096), (1, 1024), 0), out=buf154)
        del arg67_1
        buf155 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_48], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf153, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg68_1, (1024, 4096), (1, 1024), 0), out=buf155)
        del arg68_1
        del buf153
        buf156 = reinterpret_tensor(buf154, (s27, 5, 4096), (20480, 4096, 1), 0); del buf154  # reuse
        # Topologically Sorted Source Nodes: [linear_47, silu_6, linear_48, mul_62], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf156, buf155, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf155
        buf157 = empty_strided_cuda((5*s27, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_47, silu_6, linear_48, mul_62, hidden_states_33], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf156, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg69_1, (4096, 1024), (1, 4096), 0), out=buf157)
        del arg69_1
        buf159 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, attn_output_27, hidden_states_31, hidden_states_33, hidden_states_34, to_56, pow_15, variance_14, add_42, rsqrt_14, mul_63, hidden_14, hidden_states_35], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf129, buf135, buf151, buf157, arg70_1, buf159, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel, 1024, stream=stream0)
        del arg70_1
        buf160 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_28], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf159, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg71_1, (1024, 2048), (1, 1024), 0), out=buf160)
        del arg71_1
        buf161 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_28], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf159, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg72_1, (1024, 256), (1, 1024), 0), out=buf161)
        del arg72_1
        buf162 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_21], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf159, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg73_1, (1024, 256), (1, 1024), 0), out=buf162)
        del arg73_1
        buf163 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf160, arg5_1, arg6_1, buf163, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf160
        buf164 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf161, arg5_1, arg6_1, buf164, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf165 = reinterpret_tensor(buf161, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf161  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf162, buf165, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf162
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf166 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf163, buf164, buf165, scale=0.08838834764831843)
        del buf164
        del buf165
        buf167 = buf166[0]
        assert_size_stride(buf167, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf167, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf166
        buf172 = reinterpret_tensor(buf163, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf163  # reuse
        # Topologically Sorted Source Nodes: [transpose_31, attn_output_29], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf167, buf172, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf167
        buf173 = reinterpret_tensor(buf159, (5*s27, 1024), (1024, 1), 0); del buf159  # reuse
        # Topologically Sorted Source Nodes: [transpose_31, attn_output_29, attn_output_30, attn_output_31], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf172, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg74_1, (2048, 1024), (1, 2048), 0), out=buf173)
        del arg74_1
        del buf172
        buf174 = buf129; del buf129  # reuse
        buf176 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, attn_output_27, hidden_states_31, hidden_states_33, hidden_states_34, attn_output_31, hidden_states_36, to_62, pow_16, variance_15, add_46, rsqrt_15, mul_69, hidden_15, hidden_states_37], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf174, buf135, buf151, buf157, buf173, arg75_1, buf176, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel, 1024, stream=stream0)
        del arg75_1
        del buf135
        del buf151
        del buf157
        buf177 = reinterpret_tensor(buf156, (5*s27, 4096), (4096, 1), 0); del buf156  # reuse
        # Topologically Sorted Source Nodes: [to_62, pow_16, variance_15, add_46, rsqrt_15, mul_69, hidden_15, hidden_states_37, linear_54], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf176, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg76_1, (1024, 4096), (1, 1024), 0), out=buf177)
        del arg76_1
        buf178 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_55], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf176, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg77_1, (1024, 4096), (1, 1024), 0), out=buf178)
        del arg77_1
        buf179 = reinterpret_tensor(buf177, (s27, 5, 4096), (20480, 4096, 1), 0); del buf177  # reuse
        # Topologically Sorted Source Nodes: [linear_54, silu_7, linear_55, mul_71], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf179, buf178, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf178
        buf180 = reinterpret_tensor(buf176, (5*s27, 1024), (1024, 1), 0); del buf176  # reuse
        # Topologically Sorted Source Nodes: [linear_54, silu_7, linear_55, mul_71, hidden_states_38], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf179, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg78_1, (4096, 1024), (1, 4096), 0), out=buf180)
        del arg78_1
        buf182 = reinterpret_tensor(buf173, (s27, 5, 1024), (5120, 1024, 1), 0); del buf173  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, to_64, pow_17, variance_16, add_48, rsqrt_16, mul_72, hidden_16, hidden_states_40], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf174, buf180, arg79_1, buf182, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel, 1024, stream=stream0)
        del arg79_1
        buf183 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, to_64, pow_17, variance_16, add_48, rsqrt_16, mul_72, hidden_16, hidden_states_40, query_states_32], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg80_1, (1024, 2048), (1, 1024), 0), out=buf183)
        del arg80_1
        buf184 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_32], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg81_1, (1024, 256), (1, 1024), 0), out=buf184)
        del arg81_1
        buf185 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg82_1, (1024, 256), (1, 1024), 0), out=buf185)
        del arg82_1
        buf186 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf183, arg5_1, arg6_1, buf186, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf183
        buf187 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf184, arg5_1, arg6_1, buf187, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf188 = reinterpret_tensor(buf184, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf184  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf185, buf188, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf185
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf189 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf186, buf187, buf188, scale=0.08838834764831843)
        del buf187
        del buf188
        buf190 = buf189[0]
        assert_size_stride(buf190, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf190, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf189
        buf195 = reinterpret_tensor(buf186, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf186  # reuse
        # Topologically Sorted Source Nodes: [transpose_35, attn_output_33], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf190, buf195, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf190
        buf196 = reinterpret_tensor(buf182, (5*s27, 1024), (1024, 1), 0); del buf182  # reuse
        # Topologically Sorted Source Nodes: [transpose_35, attn_output_33, attn_output_34, attn_output_35], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf195, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg83_1, (2048, 1024), (1, 2048), 0), out=buf196)
        del arg83_1
        del buf195
        buf198 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, attn_output_35, hidden_states_41, to_70, pow_18, variance_17, add_52, rsqrt_17, mul_78, hidden_17, hidden_states_42], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf174, buf180, buf196, arg84_1, buf198, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel, 1024, stream=stream0)
        del arg84_1
        buf199 = reinterpret_tensor(buf179, (5*s27, 4096), (4096, 1), 0); del buf179  # reuse
        # Topologically Sorted Source Nodes: [linear_61], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf198, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg85_1, (1024, 4096), (1, 1024), 0), out=buf199)
        del arg85_1
        buf200 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_62], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf198, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg86_1, (1024, 4096), (1, 1024), 0), out=buf200)
        del arg86_1
        del buf198
        buf201 = reinterpret_tensor(buf199, (s27, 5, 4096), (20480, 4096, 1), 0); del buf199  # reuse
        # Topologically Sorted Source Nodes: [linear_61, silu_8, linear_62, mul_80], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf201, buf200, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf200
        buf202 = empty_strided_cuda((5*s27, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_61, silu_8, linear_62, mul_80, hidden_states_43], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf201, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg87_1, (4096, 1024), (1, 4096), 0), out=buf202)
        del arg87_1
        buf204 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, attn_output_35, hidden_states_41, hidden_states_43, hidden_states_44, to_72, pow_19, variance_18, add_54, rsqrt_18, mul_81, hidden_18, hidden_states_45], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf174, buf180, buf196, buf202, arg88_1, buf204, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel, 1024, stream=stream0)
        del arg88_1
        buf205 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_36], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf204, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg89_1, (1024, 2048), (1, 1024), 0), out=buf205)
        del arg89_1
        buf206 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_36], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf204, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg90_1, (1024, 256), (1, 1024), 0), out=buf206)
        del arg90_1
        buf207 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_27], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf204, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg91_1, (1024, 256), (1, 1024), 0), out=buf207)
        del arg91_1
        buf208 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf205, arg5_1, arg6_1, buf208, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf205
        buf209 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf206, arg5_1, arg6_1, buf209, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf210 = reinterpret_tensor(buf206, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf206  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf207, buf210, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf207
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf211 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf208, buf209, buf210, scale=0.08838834764831843)
        del buf209
        del buf210
        buf212 = buf211[0]
        assert_size_stride(buf212, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf212, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf211
        buf217 = reinterpret_tensor(buf208, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf208  # reuse
        # Topologically Sorted Source Nodes: [transpose_39, attn_output_37], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf212, buf217, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf212
        buf218 = reinterpret_tensor(buf204, (5*s27, 1024), (1024, 1), 0); del buf204  # reuse
        # Topologically Sorted Source Nodes: [transpose_39, attn_output_37, attn_output_38, attn_output_39], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf217, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg92_1, (2048, 1024), (1, 2048), 0), out=buf218)
        del arg92_1
        del buf217
        buf219 = buf174; del buf174  # reuse
        buf221 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, attn_output_35, hidden_states_41, hidden_states_43, hidden_states_44, attn_output_39, hidden_states_46, to_78, pow_20, variance_19, add_58, rsqrt_19, mul_87, hidden_19, hidden_states_47], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf219, buf180, buf196, buf202, buf218, arg93_1, buf221, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel, 1024, stream=stream0)
        del arg93_1
        del buf180
        del buf196
        del buf202
        buf222 = reinterpret_tensor(buf201, (5*s27, 4096), (4096, 1), 0); del buf201  # reuse
        # Topologically Sorted Source Nodes: [to_78, pow_20, variance_19, add_58, rsqrt_19, mul_87, hidden_19, hidden_states_47, linear_68], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf221, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg94_1, (1024, 4096), (1, 1024), 0), out=buf222)
        del arg94_1
        buf223 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_69], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf221, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg95_1, (1024, 4096), (1, 1024), 0), out=buf223)
        del arg95_1
        buf224 = reinterpret_tensor(buf222, (s27, 5, 4096), (20480, 4096, 1), 0); del buf222  # reuse
        # Topologically Sorted Source Nodes: [linear_68, silu_9, linear_69, mul_89], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf224, buf223, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf223
        buf225 = reinterpret_tensor(buf221, (5*s27, 1024), (1024, 1), 0); del buf221  # reuse
        # Topologically Sorted Source Nodes: [linear_68, silu_9, linear_69, mul_89, hidden_states_48], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf224, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg96_1, (4096, 1024), (1, 4096), 0), out=buf225)
        del arg96_1
        buf227 = reinterpret_tensor(buf218, (s27, 5, 1024), (5120, 1024, 1), 0); del buf218  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, to_80, pow_21, variance_20, add_60, rsqrt_20, mul_90, hidden_20, hidden_states_50], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf219, buf225, arg97_1, buf227, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9_xnumel, 1024, stream=stream0)
        del arg97_1
        buf228 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, to_80, pow_21, variance_20, add_60, rsqrt_20, mul_90, hidden_20, hidden_states_50, query_states_40], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg98_1, (1024, 2048), (1, 1024), 0), out=buf228)
        del arg98_1
        buf229 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_40], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg99_1, (1024, 256), (1, 1024), 0), out=buf229)
        del arg99_1
        buf230 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_30], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg100_1, (1024, 256), (1, 1024), 0), out=buf230)
        del arg100_1
        buf231 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf228, arg5_1, arg6_1, buf231, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf228
        buf232 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf229, arg5_1, arg6_1, buf232, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        buf233 = reinterpret_tensor(buf229, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf229  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf230, buf233, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf230
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf234 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf231, buf232, buf233, scale=0.08838834764831843)
        del buf232
        del buf233
        buf235 = buf234[0]
        assert_size_stride(buf235, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf235, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf234
        buf240 = reinterpret_tensor(buf231, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf231  # reuse
        # Topologically Sorted Source Nodes: [transpose_43, attn_output_41], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf235, buf240, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf235
        buf241 = reinterpret_tensor(buf227, (5*s27, 1024), (1024, 1), 0); del buf227  # reuse
        # Topologically Sorted Source Nodes: [transpose_43, attn_output_41, attn_output_42, attn_output_43], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf240, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg101_1, (2048, 1024), (1, 2048), 0), out=buf241)
        del arg101_1
        del buf240
        buf243 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, attn_output_43, hidden_states_51, to_86, pow_22, variance_21, add_64, rsqrt_21, mul_96, hidden_21, hidden_states_52], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf219, buf225, buf241, arg102_1, buf243, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10_xnumel, 1024, stream=stream0)
        del arg102_1
        buf244 = reinterpret_tensor(buf224, (5*s27, 4096), (4096, 1), 0); del buf224  # reuse
        # Topologically Sorted Source Nodes: [linear_75], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf243, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg103_1, (1024, 4096), (1, 1024), 0), out=buf244)
        del arg103_1
        buf245 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_76], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf243, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg104_1, (1024, 4096), (1, 1024), 0), out=buf245)
        del arg104_1
        del buf243
        buf246 = reinterpret_tensor(buf244, (s27, 5, 4096), (20480, 4096, 1), 0); del buf244  # reuse
        # Topologically Sorted Source Nodes: [linear_75, silu_10, linear_76, mul_98], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf246, buf245, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf245
        buf247 = empty_strided_cuda((5*s27, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_75, silu_10, linear_76, mul_98, hidden_states_53], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf246, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg105_1, (4096, 1024), (1, 4096), 0), out=buf247)
        del arg105_1
        buf249 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, to_88, pow_23, variance_22, add_66, rsqrt_22, mul_99, hidden_22, hidden_states_55], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf219, buf225, buf241, buf247, arg106_1, buf249, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11_xnumel, 1024, stream=stream0)
        del arg106_1
        buf250 = empty_strided_cuda((5*s27, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_44], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf249, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg107_1, (1024, 2048), (1, 1024), 0), out=buf250)
        del arg107_1
        buf251 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_44], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf249, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg108_1, (1024, 256), (1, 1024), 0), out=buf251)
        del arg108_1
        buf252 = empty_strided_cuda((5*s27, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_33], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf249, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg109_1, (1024, 256), (1, 1024), 0), out=buf252)
        del arg109_1
        buf253 = empty_strided_cuda((s27, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf250, arg5_1, arg6_1, buf253, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1_xnumel, stream=stream0)
        del buf250
        buf254 = empty_strided_cuda((s27, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf251, arg5_1, arg6_1, buf254, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2_xnumel, stream=stream0)
        del arg5_1
        del arg6_1
        buf255 = reinterpret_tensor(buf251, (s27, 2, 5, 128), (1280, 640, 128, 1), 0); del buf251  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel = 1280*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf252, buf255, triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3_xnumel, stream=stream0)
        del buf252
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf256 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf253, buf254, buf255, scale=0.08838834764831843)
        del buf254
        del buf255
        buf257 = buf256[0]
        assert_size_stride(buf257, (s27, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf257, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf256
        buf262 = reinterpret_tensor(buf253, (s27, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf253  # reuse
        # Topologically Sorted Source Nodes: [transpose_47, attn_output_45], Original ATen: [aten.transpose, aten.clone]
        triton_poi_fused_clone_transpose_4_xnumel = 10240*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf257, buf262, triton_poi_fused_clone_transpose_4_xnumel, stream=stream0)
        del buf257
        buf263 = reinterpret_tensor(buf249, (5*s27, 1024), (1024, 1), 0); del buf249  # reuse
        # Topologically Sorted Source Nodes: [transpose_47, attn_output_45, attn_output_46, attn_output_47], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf262, (5*s27, 2048), (2048, 1), 0), reinterpret_tensor(arg110_1, (2048, 1024), (1, 2048), 0), out=buf263)
        del arg110_1
        del buf262
        buf264 = buf219; del buf219  # reuse
        buf266 = empty_strided_cuda((s27, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, attn_output_47, hidden_states_56, to_94, pow_24, variance_23, add_70, rsqrt_23, mul_105, hidden_23, hidden_states_57], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf264, buf225, buf241, buf247, buf263, arg111_1, buf266, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12_xnumel, 1024, stream=stream0)
        del arg111_1
        del buf225
        del buf241
        del buf247
        del buf263
        buf267 = reinterpret_tensor(buf246, (5*s27, 4096), (4096, 1), 0); del buf246  # reuse
        # Topologically Sorted Source Nodes: [to_94, pow_24, variance_23, add_70, rsqrt_23, mul_105, hidden_23, hidden_states_57, linear_82], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf266, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg112_1, (1024, 4096), (1, 1024), 0), out=buf267)
        del arg112_1
        buf268 = empty_strided_cuda((5*s27, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_83], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf266, (5*s27, 1024), (1024, 1), 0), reinterpret_tensor(arg113_1, (1024, 4096), (1, 1024), 0), out=buf268)
        del arg113_1
        buf269 = reinterpret_tensor(buf267, (s27, 5, 4096), (20480, 4096, 1), 0); del buf267  # reuse
        # Topologically Sorted Source Nodes: [linear_82, silu_11, linear_83, mul_107], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        triton_poi_fused__unsafe_view_mul_silu_6_xnumel = 20480*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf269, buf268, triton_poi_fused__unsafe_view_mul_silu_6_xnumel, stream=stream0)
        del buf268
        buf270 = reinterpret_tensor(buf266, (5*s27, 1024), (1024, 1), 0); del buf266  # reuse
        # Topologically Sorted Source Nodes: [linear_82, silu_11, linear_83, mul_107, hidden_states_58], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf269, (5*s27, 4096), (4096, 1), 0), reinterpret_tensor(arg114_1, (4096, 1024), (1, 4096), 0), out=buf270)
        del arg114_1
        del buf269
        buf272 = buf264; del buf264  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_58, hidden_states_59, to_96, pow_25, variance_24, add_72, rsqrt_24, mul_108, hidden_24, hidden_states_60], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13_xnumel = 5*s27
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13.run(buf272, buf270, arg115_1, triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13_xnumel, 1024, stream=stream0)
        del arg115_1
        del buf270
    return (buf272, )


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1 = args
        args.clear()
        s27 = arg0_1
        partition0_args = [arg3_1, arg1_1, arg2_1, arg4_1, arg7_1, arg8_1, arg9_1, arg10_1, arg5_1, arg6_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, s27]
        del arg3_1, arg1_1, arg2_1, arg4_1, arg7_1, arg8_1, arg9_1, arg10_1, arg5_1, arg6_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1
        (buf272,) = self.partitions[0](partition0_args)
        del partition0_args
        return (reinterpret_tensor(buf272, (1, s27, 1024), (5120*s27, 5120, 1), 0), )

runner = Runner(partitions=[partition_0,])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def get_args():
    from torch._dynamo.testing import rand_strided
    arg0_1 = 838
    arg1_1 = rand_strided((1, 838, 4, 64), (214528, 256, 64, 1), device='cuda:0', dtype=torch.bfloat16)
    arg2_1 = rand_strided((1024, 64), (64, 1), device='cuda:0', dtype=torch.bfloat16)
    arg3_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg4_1 = rand_strided((1, 1, 1, 1024), (1024, 1024, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg5_1 = rand_strided((32768, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg6_1 = rand_strided((32768, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg7_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg8_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg9_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg10_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg11_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg12_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg13_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg14_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg15_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg16_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg17_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg18_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg19_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg20_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg21_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg22_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg23_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg24_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg25_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg26_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg27_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg28_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg29_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg30_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg31_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg32_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg33_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg34_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg35_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg36_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg37_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg38_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg39_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg40_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg41_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg42_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg43_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg44_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg45_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg46_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg47_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg48_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg49_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg50_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg51_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg52_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg53_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg54_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg55_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg56_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg57_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg58_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg59_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg60_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg61_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg62_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg63_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg64_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg65_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg66_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg67_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg68_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg69_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg70_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg71_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg72_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg73_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg74_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg75_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg76_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg77_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg78_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg79_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg80_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg81_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg82_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg83_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg84_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg85_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg86_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg87_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg88_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg89_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg90_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg91_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg92_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg93_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg94_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg95_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg96_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg97_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg98_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg99_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg100_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg101_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg102_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg103_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg104_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg105_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg106_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg107_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg108_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg109_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg110_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg111_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg112_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg113_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg114_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg115_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    return [arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1]


def benchmark_compiled_module(args, times=10, repeat=10):
    from torch._inductor.utils import print_performance
    fn = lambda: call(list(args))
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))
