# AOT ID: ['2_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
from torch._C import _cuda_getCurrentRawStream as get_raw_stream



# kernel path: /tmp/torchinductor_shingokuga/tu/ctuxkzqr3gduxjb5ngcbwjluoji6incs7gzvxf4vrmxti3di4plm.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten.expand, aten._unsafe_view, aten.add, aten.cat, aten.view, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add => add_1
#   hidden => convert_element_type_3
#   hidden_states => mul_1
#   mul => mul
#   pow_1 => pow_1
#   rsqrt => rsqrt
#   special_tokens => expand
#   to => convert_element_type_2
#   variance => mean
#   x => add, view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %mm : Tensor "bf16[4, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm]
#   %arg2_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg2_1]
#   %buf1 : Tensor "f32[1, 5, 1][5, 1, 5]cuda:0" = PlaceHolder[target=buf1]
#   %arg6_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg6_1]
#   %expand : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg3_1, [1, 1, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, 1, 4, 1024][4096, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [1, 1, 4, 1024]), kwargs = {})
#   %add : Tensor "bf16[1, 1, 4, 1024][4096, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_1, %arg2_1), kwargs = {})
#   %cat : Tensor "bf16[1, 1, 5, 1024][5120, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %add], 2), kwargs = {})
#   %view_2 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [1, 5, 1024]), kwargs = {})
#   %convert_element_type_2 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_2, torch.float32), kwargs = {})
#   %pow_1 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_2, 2), kwargs = {})
#   %mean : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
#   %add_1 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
#   %rsqrt : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_1,), kwargs = {})
#   %mul : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_2, %rsqrt), kwargs = {})
#   %convert_element_type_3 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
#   %mul_1 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_3, %arg6_1), kwargs = {})
#   return %buf1,%mul_1
triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_0 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 34816}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 5
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = xindex
    r0_1 = r0_index
    tmp28 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_1, [XBLOCK, R0_BLOCK])), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 5, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + (r0_1 + 1024*((-1) + x0)), tmp6 & xmask, other=0.0).to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (tl.broadcast_to(r0_1, [XBLOCK, R0_BLOCK])), tmp6 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp9 + tmp10
    tmp12 = tl.full(tmp11.shape, 0.0, tmp11.dtype)
    tmp13 = tl.where(tmp6, tmp11, tmp12)
    tmp14 = tl.where(tmp4, tmp5, tmp13)
    tmp15 = tmp14.to(tl.float32)
    tmp16 = tmp15 * tmp15
    tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
    tmp19 = tl.where(xmask, tmp17, 0)
    tmp20 = tl.sum(tmp19, 1)[:, None].to(tl.float32)
    tmp21 = tl.full([1, 1], 1024.0, tl.float32)
    tmp22 = (tmp20 / tmp21)
    tmp23 = tl.full([1, 1], 1e-05, tl.float32)
    tmp24 = tmp22 + tmp23
    tmp25 = libdevice.rsqrt(tmp24)
    tmp26 = tmp15 * tmp25
    tmp27 = tmp26.to(tl.float32)
    tmp29 = tmp27 * tmp28
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp29, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/cr/ccr56jaygxchnuqxrw3p54ci2jkwinjlb2puq2vm77ynaa67qdsm.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_11
#   k_embed => add_3
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_13
#   key_states_3 => clone_1
#   mul_2 => mul_2
#   mul_3 => mul_3
#   mul_4 => mul_4
#   mul_5 => mul_5
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_10
#   q_embed => add_2
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_12
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm_1 : Tensor "bf16[5, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_1]
#   %arg4_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %arg5_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %view_4 : Tensor "bf16[1, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [1, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[1, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [1, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_10 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_10, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[1, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [1, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[1, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [1, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg4_1, [%iota]), kwargs = {})
#   %mul_2 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %index), kwargs = {})
#   %neg : Tensor "f32[1, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_3 : Tensor "f32[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_2 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_2, %mul_3), kwargs = {})
#   %convert_element_type_12 : Tensor "bf16[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_12,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_4 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg_1 : Tensor "f32[1, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_5 : Tensor "f32[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_3 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[1, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [1, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[1, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [1, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf6
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16384}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 104960}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 10240
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x3 = xindex
    x0 = (xindex % 128)
    x2 = xindex // 2048
    x4 = xindex // 128
    x1 = ((xindex // 128) % 16)
    tmp0 = tl.load(in_ptr0 + (x3), xmask).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = x0
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 64, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tl.load(in_ptr0 + (64 + 128*x4 + (x0)), tmp9 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp5 >= tmp8
    tmp16 = tl.full([1], 128, tl.int64)
    tmp17 = tmp5 < tmp16
    tmp18 = tl.load(in_ptr0 + (128*x4 + ((-64) + x0)), tmp15 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tl.where(tmp9, tmp14, tmp21)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp4 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*x2 + 640*x1), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qt/cqt3rtv2korf3awge5kibuixgf2grh5ghc5gkgr7cvwh5rok7qxt.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_11
#   k_embed => add_3
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_13
#   key_states_3 => clone_1
#   mul_2 => mul_2
#   mul_3 => mul_3
#   mul_4 => mul_4
#   mul_5 => mul_5
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_10
#   q_embed => add_2
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_12
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm_2 : Tensor "bf16[5, 256][256, 1]cuda:0" = PlaceHolder[target=mm_2]
#   %arg4_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %arg5_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %view_4 : Tensor "bf16[1, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [1, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[1, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [1, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_10 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_10, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[1, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [1, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[1, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [1, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg4_1, [%iota]), kwargs = {})
#   %mul_2 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %index), kwargs = {})
#   %neg : Tensor "f32[1, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_3 : Tensor "f32[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_2 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_2, %mul_3), kwargs = {})
#   %convert_element_type_12 : Tensor "bf16[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_12,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_4 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg_1 : Tensor "f32[1, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_5 : Tensor "f32[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_3 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[1, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [1, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[1, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [1, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf7
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2048}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 15360}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1280
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x3 = xindex
    x0 = (xindex % 128)
    x2 = xindex // 256
    x4 = xindex // 128
    x1 = ((xindex // 128) % 2)
    tmp0 = tl.load(in_ptr0 + (x3), xmask).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = x0
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 64, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tl.load(in_ptr0 + (64 + 128*x4 + (x0)), tmp9 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp5 >= tmp8
    tmp16 = tl.full([1], 128, tl.int64)
    tmp17 = tmp5 < tmp16
    tmp18 = tl.load(in_ptr0 + (128*x4 + ((-64) + x0)), tmp15 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tl.where(tmp9, tmp14, tmp21)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp4 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*x2 + 640*x1), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4x/c4xfez5xxf56eax2gfz6gbxskbozyexvv4nfoqkducq3tdn27lc7.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_11
#   k_embed => add_3
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_13
#   key_states_3 => clone_1
#   mul_2 => mul_2
#   mul_3 => mul_3
#   mul_4 => mul_4
#   mul_5 => mul_5
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_10
#   q_embed => add_2
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_12
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm_3 : Tensor "bf16[5, 256][256, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %view_4 : Tensor "bf16[1, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [1, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[1, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [1, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_10 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_10, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[1, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [1, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[1, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [1, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg4_1, [%iota]), kwargs = {})
#   %mul_2 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %index), kwargs = {})
#   %neg : Tensor "f32[1, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_3 : Tensor "f32[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_2 : Tensor "f32[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_2, %mul_3), kwargs = {})
#   %convert_element_type_12 : Tensor "bf16[1, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[1, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_12,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_4 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg_1 : Tensor "f32[1, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_5 : Tensor "f32[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_3 : Tensor "f32[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[1, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [1, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[1, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [1, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[1, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[1, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf8
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2048}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 7680}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1280
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 128)
    x1 = ((xindex // 128) % 5)
    x2 = xindex // 640
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + 256*x1), xmask).to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp0, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/5x/c5xdlbpgypkcl4bnvqgqg55aojh4lyivnud5xxrhd4m6jqrxenna.py
# Topologically Sorted Source Nodes: [transpose_3, attn_output_1], Original ATen: [aten.transpose, aten.clone]
# Source node to ATen node mapping:
#   attn_output_1 => clone_3
#   transpose_3 => permute_7
# Graph fragment:
#   %getitem_4 : Tensor "bf16[1, 16, 5, 128][10240, 640, 128, 1]cuda:0" = PlaceHolder[target=getitem_4]
#   %permute_7 : Tensor "bf16[1, 5, 16, 128][10240, 128, 640, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%getitem_4, [0, 2, 1, 3]), kwargs = {})
#   %clone_3 : Tensor "bf16[1, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_7,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_3
triton_poi_fused_clone_transpose_4 = async_compile.triton('triton_poi_fused_clone_transpose_4', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16384}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_transpose_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 61440}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_transpose_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 10240
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 128)
    x1 = ((xindex // 128) % 16)
    x2 = xindex // 2048
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + 640*x1), xmask).to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp0, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/gu/cguyxu3goer7xnk7o7m6algn47wlq74pr3sqawo7syogrij3kkrr.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, to_6, pow_2, variance_1, add_4, rsqrt_1, mul_6, hidden_1, hidden_states_2], Original ATen: [aten.expand, aten._unsafe_view, aten.add, aten.cat, aten.view, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_4 => add_5
#   attn_output_3 => view_14
#   hidden_1 => convert_element_type_17
#   hidden_states_1 => add_4
#   hidden_states_2 => mul_7
#   mul_6 => mul_6
#   pow_2 => pow_2
#   rsqrt_1 => rsqrt_1
#   special_tokens => expand
#   to_6 => convert_element_type_16
#   variance_1 => mean_1
#   x => add, view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %mm : Tensor "bf16[4, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm]
#   %arg2_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg2_1]
#   %mm_4 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_4]
#   %buf17 : Tensor "f32[1, 5, 1][5, 1, 5]cuda:0" = PlaceHolder[target=buf17]
#   %arg11_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg11_1]
#   %expand : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg3_1, [1, 1, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, 1, 4, 1024][4096, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [1, 1, 4, 1024]), kwargs = {})
#   %add : Tensor "bf16[1, 1, 4, 1024][4096, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_1, %arg2_1), kwargs = {})
#   %cat : Tensor "bf16[1, 1, 5, 1024][5120, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %add], 2), kwargs = {})
#   %view_2 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [1, 5, 1024]), kwargs = {})
#   %view_14 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_4, [1, 5, 1024]), kwargs = {})
#   %add_4 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %view_14), kwargs = {})
#   %convert_element_type_16 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_4, torch.float32), kwargs = {})
#   %pow_2 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_16, 2), kwargs = {})
#   %mean_1 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
#   %add_5 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
#   %rsqrt_1 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_5,), kwargs = {})
#   %mul_6 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_4, %rsqrt_1), kwargs = {})
#   %convert_element_type_17 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_6, torch.bfloat16), kwargs = {})
#   %mul_7 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_17, %arg11_1), kwargs = {})
#   return %buf17,%mul_7
triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 45056}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 5
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = xindex
    r0_1 = r0_index
    tmp15 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp30 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_1, [XBLOCK, R0_BLOCK])), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 5, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + (r0_1 + 1024*((-1) + x0)), tmp6 & xmask, other=0.0).to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (tl.broadcast_to(r0_1, [XBLOCK, R0_BLOCK])), tmp6 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp9 + tmp10
    tmp12 = tl.full(tmp11.shape, 0.0, tmp11.dtype)
    tmp13 = tl.where(tmp6, tmp11, tmp12)
    tmp14 = tl.where(tmp4, tmp5, tmp13)
    tmp16 = tmp14 + tmp15
    tmp17 = tmp16.to(tl.float32)
    tmp18 = tmp17 * tmp17
    tmp19 = tl.broadcast_to(tmp18, [XBLOCK, R0_BLOCK])
    tmp21 = tl.where(xmask, tmp19, 0)
    tmp22 = tl.sum(tmp21, 1)[:, None].to(tl.float32)
    tmp23 = tl.full([1, 1], 1024.0, tl.float32)
    tmp24 = (tmp22 / tmp23)
    tmp25 = tl.full([1, 1], 1e-05, tl.float32)
    tmp26 = tmp24 + tmp25
    tmp27 = libdevice.rsqrt(tmp26)
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28.to(tl.float32)
    tmp31 = tmp29 * tmp30
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp31, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/5w/c5wy2tm2dt5664afcpdw4g62suckcblzzttbu3itwusdj7dbpjba.py
# Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
# Source node to ATen node mapping:
#   linear_5 => view_16
#   linear_6 => view_18
#   mul_8 => mul_8
#   silu => add_6, convert_element_type_20, convert_element_type_21, div, exp, neg_2
# Graph fragment:
#   %mm_5 : Tensor "bf16[5, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_5]
#   %mm_6 : Tensor "bf16[5, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %view_16 : Tensor "bf16[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_5, [1, 5, 4096]), kwargs = {})
#   %convert_element_type_20 : Tensor "f32[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_16, torch.float32), kwargs = {})
#   %neg_2 : Tensor "f32[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%convert_element_type_20,), kwargs = {})
#   %exp : Tensor "f32[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%neg_2,), kwargs = {})
#   %add_6 : Tensor "f32[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%exp, 1), kwargs = {})
#   %div : Tensor "f32[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_20, %add_6), kwargs = {})
#   %convert_element_type_21 : Tensor "bf16[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%div, torch.bfloat16), kwargs = {})
#   %view_18 : Tensor "bf16[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_6, [1, 5, 4096]), kwargs = {})
#   %mul_8 : Tensor "bf16[1, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_21, %view_18), kwargs = {})
#   return %mul_8
triton_poi_fused__unsafe_view_mul_silu_6 = async_compile.triton('triton_poi_fused__unsafe_view_mul_silu_6', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 32768}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__unsafe_view_mul_silu_6', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 163840}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__unsafe_view_mul_silu_6(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 20480
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = -tmp1
    tmp3 = libdevice.exp(tmp2)
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = (tmp1 / tmp5)
    tmp7 = tmp6.to(tl.float32)
    tmp9 = tmp7 * tmp8
    tl.store(in_out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/vf/cvfhdiwnec5jwzhz75nuj6zurtotmfkfjxlftig7t6atknvydhla.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5], Original ATen: [aten.expand, aten._unsafe_view, aten.add, aten.cat, aten.view, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_6 => add_8
#   attn_output_3 => view_14
#   hidden_2 => convert_element_type_27
#   hidden_states_1 => add_4
#   hidden_states_3 => view_20
#   hidden_states_4 => add_7
#   hidden_states_5 => mul_10
#   mul_9 => mul_9
#   pow_3 => pow_3
#   rsqrt_2 => rsqrt_2
#   special_tokens => expand
#   to_8 => convert_element_type_26
#   variance_2 => mean_2
#   x => add, view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %mm : Tensor "bf16[4, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm]
#   %arg2_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg2_1]
#   %mm_4 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_4]
#   %mm_7 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_7]
#   %add_7 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_7]
#   %buf24 : Tensor "f32[1, 5, 1][5, 1, 5]cuda:0" = PlaceHolder[target=buf24]
#   %arg15_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg15_1]
#   %expand : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg3_1, [1, 1, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, 1, 4, 1024][4096, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [1, 1, 4, 1024]), kwargs = {})
#   %add : Tensor "bf16[1, 1, 4, 1024][4096, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_1, %arg2_1), kwargs = {})
#   %cat : Tensor "bf16[1, 1, 5, 1024][5120, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %add], 2), kwargs = {})
#   %view_2 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [1, 5, 1024]), kwargs = {})
#   %view_14 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_4, [1, 5, 1024]), kwargs = {})
#   %add_4 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %view_14), kwargs = {})
#   %view_20 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_7, [1, 5, 1024]), kwargs = {})
#   %add_7 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_4, %view_20), kwargs = {})
#   %convert_element_type_26 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.float32), kwargs = {})
#   %pow_3 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_26, 2), kwargs = {})
#   %mean_2 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
#   %add_8 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {})
#   %rsqrt_2 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {})
#   %mul_9 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_7, %rsqrt_2), kwargs = {})
#   %convert_element_type_27 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_9, torch.bfloat16), kwargs = {})
#   %mul_10 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_27, %arg15_1), kwargs = {})
#   return %add_7,%buf24,%mul_10
triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 75776}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 5
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = xindex
    r0_1 = r0_index
    tmp15 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp17 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp32 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_1, [XBLOCK, R0_BLOCK])), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 5, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + (r0_1 + 1024*((-1) + x0)), tmp6 & xmask, other=0.0).to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (tl.broadcast_to(r0_1, [XBLOCK, R0_BLOCK])), tmp6 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp9 + tmp10
    tmp12 = tl.full(tmp11.shape, 0.0, tmp11.dtype)
    tmp13 = tl.where(tmp6, tmp11, tmp12)
    tmp14 = tl.where(tmp4, tmp5, tmp13)
    tmp16 = tmp14 + tmp15
    tmp18 = tmp16 + tmp17
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tmp19 * tmp19
    tmp21 = tl.broadcast_to(tmp20, [XBLOCK, R0_BLOCK])
    tmp23 = tl.where(xmask, tmp21, 0)
    tmp24 = tl.sum(tmp23, 1)[:, None].to(tl.float32)
    tmp25 = tl.full([1, 1], 1024.0, tl.float32)
    tmp26 = (tmp24 / tmp25)
    tmp27 = tl.full([1, 1], 1e-05, tl.float32)
    tmp28 = tmp26 + tmp27
    tmp29 = libdevice.rsqrt(tmp28)
    tmp30 = tmp19 * tmp29
    tmp31 = tmp30.to(tl.float32)
    tmp33 = tmp31 * tmp32
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp18, xmask)
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp33, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/yg/cyg57cvideor6efyro77yjdtuaby52ksslhbaehsikh3n4fisjkt.py
# Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_10 => add_12
#   attn_output_7 => view_32
#   hidden_3 => convert_element_type_41
#   hidden_states_6 => add_11
#   hidden_states_7 => mul_16
#   mul_15 => mul_15
#   pow_4 => pow_4
#   rsqrt_3 => rsqrt_3
#   to_14 => convert_element_type_40
#   variance_3 => mean_3
# Graph fragment:
#   %add_7 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_7]
#   %mm_11 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_11]
#   %buf40 : Tensor "f32[1, 5, 1][5, 1, 5]cuda:0" = PlaceHolder[target=buf40]
#   %arg20_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg20_1]
#   %view_32 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_11, [1, 5, 1024]), kwargs = {})
#   %add_11 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_32), kwargs = {})
#   %convert_element_type_40 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {})
#   %pow_4 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_40, 2), kwargs = {})
#   %mean_3 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {})
#   %add_12 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {})
#   %rsqrt_3 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {})
#   %mul_15 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_11, %rsqrt_3), kwargs = {})
#   %convert_element_type_41 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_15, torch.bfloat16), kwargs = {})
#   %mul_16 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_41, %arg20_1), kwargs = {})
#   return %buf40,%mul_16
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 43008}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 5
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp3 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
    tmp7 = tl.where(xmask, tmp5, 0)
    tmp8 = tl.sum(tmp7, 1)[:, None].to(tl.float32)
    tmp9 = tl.full([1, 1], 1024.0, tl.float32)
    tmp10 = (tmp8 / tmp9)
    tmp11 = tl.full([1, 1], 1e-05, tl.float32)
    tmp12 = tmp10 + tmp11
    tmp13 = libdevice.rsqrt(tmp12)
    tmp14 = tmp3 * tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp17 = tmp15 * tmp16
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp17, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/nz/cnzp7txoqdpzusqsr4ds3jn7bmje2766mmompg5fdcpj62hjtywl.py
# Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, hidden_states_8, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_12 => add_15
#   attn_output_7 => view_32
#   hidden_4 => convert_element_type_51
#   hidden_states_10 => mul_19
#   hidden_states_6 => add_11
#   hidden_states_8 => view_38
#   hidden_states_9 => add_14
#   mul_18 => mul_18
#   pow_5 => pow_5
#   rsqrt_4 => rsqrt_4
#   to_16 => convert_element_type_50
#   variance_4 => mean_4
# Graph fragment:
#   %add_7 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_7]
#   %mm_11 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_11]
#   %mm_14 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_14]
#   %buf46 : Tensor "f32[1, 5, 1][5, 1, 5]cuda:0" = PlaceHolder[target=buf46]
#   %arg24_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg24_1]
#   %view_32 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_11, [1, 5, 1024]), kwargs = {})
#   %add_11 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_32), kwargs = {})
#   %view_38 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_14, [1, 5, 1024]), kwargs = {})
#   %add_14 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_38), kwargs = {})
#   %convert_element_type_50 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_14, torch.float32), kwargs = {})
#   %pow_5 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_50, 2), kwargs = {})
#   %mean_4 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {})
#   %add_15 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {})
#   %rsqrt_4 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_15,), kwargs = {})
#   %mul_18 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_14, %rsqrt_4), kwargs = {})
#   %convert_element_type_51 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_18, torch.bfloat16), kwargs = {})
#   %mul_19 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_51, %arg24_1), kwargs = {})
#   return %buf46,%mul_19
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 53248}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 5
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp18 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp5 * tmp5
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
    tmp9 = tl.where(xmask, tmp7, 0)
    tmp10 = tl.sum(tmp9, 1)[:, None].to(tl.float32)
    tmp11 = tl.full([1, 1], 1024.0, tl.float32)
    tmp12 = (tmp10 / tmp11)
    tmp13 = tl.full([1, 1], 1e-05, tl.float32)
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp5 * tmp15
    tmp17 = tmp16.to(tl.float32)
    tmp19 = tmp17 * tmp18
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp19, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/d2/cd26h53mowkucxzvzqnixiroaj4dawveylnf62dqbkaicebqrvgg.py
# Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, to_22, pow_6, variance_5, add_16, rsqrt_5, mul_24, hidden_5, hidden_states_12], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_16 => add_19
#   attn_output_11 => view_50
#   attn_output_7 => view_32
#   hidden_5 => convert_element_type_65
#   hidden_states_11 => add_18
#   hidden_states_12 => mul_25
#   hidden_states_6 => add_11
#   hidden_states_8 => view_38
#   hidden_states_9 => add_14
#   mul_24 => mul_24
#   pow_6 => pow_6
#   rsqrt_5 => rsqrt_5
#   to_22 => convert_element_type_64
#   variance_5 => mean_5
# Graph fragment:
#   %add_7 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_7]
#   %mm_11 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_11]
#   %mm_14 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_14]
#   %mm_18 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_18]
#   %buf62 : Tensor "f32[1, 5, 1][5, 1, 5]cuda:0" = PlaceHolder[target=buf62]
#   %arg29_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg29_1]
#   %view_32 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_11, [1, 5, 1024]), kwargs = {})
#   %add_11 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_32), kwargs = {})
#   %view_38 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_14, [1, 5, 1024]), kwargs = {})
#   %add_14 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_38), kwargs = {})
#   %view_50 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_18, [1, 5, 1024]), kwargs = {})
#   %add_18 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_14, %view_50), kwargs = {})
#   %convert_element_type_64 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_18, torch.float32), kwargs = {})
#   %pow_6 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_64, 2), kwargs = {})
#   %mean_5 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_6, [-1], True), kwargs = {})
#   %add_19 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_5, 1e-05), kwargs = {})
#   %rsqrt_5 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_19,), kwargs = {})
#   %mul_24 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_18, %rsqrt_5), kwargs = {})
#   %convert_element_type_65 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_24, torch.bfloat16), kwargs = {})
#   %mul_25 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_65, %arg29_1), kwargs = {})
#   return %buf62,%mul_25
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 63488}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 5
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp20 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tmp7 * tmp7
    tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
    tmp11 = tl.where(xmask, tmp9, 0)
    tmp12 = tl.sum(tmp11, 1)[:, None].to(tl.float32)
    tmp13 = tl.full([1, 1], 1024.0, tl.float32)
    tmp14 = (tmp12 / tmp13)
    tmp15 = tl.full([1, 1], 1e-05, tl.float32)
    tmp16 = tmp14 + tmp15
    tmp17 = libdevice.rsqrt(tmp16)
    tmp18 = tmp7 * tmp17
    tmp19 = tmp18.to(tl.float32)
    tmp21 = tmp19 * tmp20
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp21, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/rw/crwo3ravsojr6dslmng6oso6xwfjynjkghsrdssy5dsfswa6cemy.py
# Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, to_24, pow_7, variance_6, add_18, rsqrt_6, mul_27, hidden_6, hidden_states_15], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_18 => add_22
#   attn_output_11 => view_50
#   attn_output_7 => view_32
#   hidden_6 => convert_element_type_75
#   hidden_states_11 => add_18
#   hidden_states_13 => view_56
#   hidden_states_14 => add_21
#   hidden_states_15 => mul_28
#   hidden_states_6 => add_11
#   hidden_states_8 => view_38
#   hidden_states_9 => add_14
#   mul_27 => mul_27
#   pow_7 => pow_7
#   rsqrt_6 => rsqrt_6
#   to_24 => convert_element_type_74
#   variance_6 => mean_6
# Graph fragment:
#   %add_7 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_7]
#   %mm_11 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_11]
#   %mm_14 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_14]
#   %mm_18 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_18]
#   %mm_21 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_21]
#   %add_21 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_21]
#   %buf69 : Tensor "f32[1, 5, 1][5, 1, 5]cuda:0" = PlaceHolder[target=buf69]
#   %arg33_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg33_1]
#   %view_32 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_11, [1, 5, 1024]), kwargs = {})
#   %add_11 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %view_32), kwargs = {})
#   %view_38 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_14, [1, 5, 1024]), kwargs = {})
#   %add_14 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_38), kwargs = {})
#   %view_50 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_18, [1, 5, 1024]), kwargs = {})
#   %add_18 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_14, %view_50), kwargs = {})
#   %view_56 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_21, [1, 5, 1024]), kwargs = {})
#   %add_21 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_18, %view_56), kwargs = {})
#   %convert_element_type_74 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_21, torch.float32), kwargs = {})
#   %pow_7 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_74, 2), kwargs = {})
#   %mean_6 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_7, [-1], True), kwargs = {})
#   %add_22 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_6, 1e-05), kwargs = {})
#   %rsqrt_6 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_22,), kwargs = {})
#   %mul_27 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_21, %rsqrt_6), kwargs = {})
#   %convert_element_type_75 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.bfloat16), kwargs = {})
#   %mul_28 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_75, %arg33_1), kwargs = {})
#   return %add_21,%buf69,%mul_28
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 94208}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 5
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp22 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp8 = tmp6 + tmp7
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp9 * tmp9
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp13 = tl.where(xmask, tmp11, 0)
    tmp14 = tl.sum(tmp13, 1)[:, None].to(tl.float32)
    tmp15 = tl.full([1, 1], 1024.0, tl.float32)
    tmp16 = (tmp14 / tmp15)
    tmp17 = tl.full([1, 1], 1e-05, tl.float32)
    tmp18 = tmp16 + tmp17
    tmp19 = libdevice.rsqrt(tmp18)
    tmp20 = tmp9 * tmp19
    tmp21 = tmp20.to(tl.float32)
    tmp23 = tmp21 * tmp22
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp8, xmask)
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp23, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/sr/csrhsvygq6hcqqhshfjol7fftabbo7754cdyhitmv2zefkmmamwq.py
# Topologically Sorted Source Nodes: [attn_output_47, hidden_states_56, hidden_states_58, hidden_states_59, to_96, pow_25, variance_24, add_72, rsqrt_24, mul_108, hidden_24, hidden_states_60], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_72 => add_85
#   attn_output_47 => view_212
#   hidden_24 => convert_element_type_291
#   hidden_states_56 => add_81
#   hidden_states_58 => view_218
#   hidden_states_59 => add_84
#   hidden_states_60 => mul_109
#   mul_108 => mul_108
#   pow_25 => pow_25
#   rsqrt_24 => rsqrt_24
#   to_96 => convert_element_type_290
#   variance_24 => mean_24
# Graph fragment:
#   %add_77 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_77]
#   %mm_81 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_81]
#   %mm_84 : Tensor "bf16[5, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_84]
#   %buf271 : Tensor "f32[1, 5, 1][5, 1, 5]cuda:0" = PlaceHolder[target=buf271]
#   %arg114_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg114_1]
#   %view_212 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_81, [1, 5, 1024]), kwargs = {})
#   %add_81 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_77, %view_212), kwargs = {})
#   %view_218 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_84, [1, 5, 1024]), kwargs = {})
#   %add_84 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_81, %view_218), kwargs = {})
#   %convert_element_type_290 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_84, torch.float32), kwargs = {})
#   %pow_25 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_290, 2), kwargs = {})
#   %mean_24 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_25, [-1], True), kwargs = {})
#   %add_85 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_24, 1e-05), kwargs = {})
#   %rsqrt_24 : Tensor "f32[1, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_85,), kwargs = {})
#   %mul_108 : Tensor "f32[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_84, %rsqrt_24), kwargs = {})
#   %convert_element_type_291 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_108, torch.bfloat16), kwargs = {})
#   %mul_109 : Tensor "bf16[1, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_291, %arg114_1), kwargs = {})
#   return %buf271,%mul_109
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 8, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 53248}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 5
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp18 = tl.load(in_ptr2 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp5 * tmp5
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
    tmp9 = tl.where(xmask, tmp7, 0)
    tmp10 = tl.sum(tmp9, 1)[:, None].to(tl.float32)
    tmp11 = tl.full([1, 1], 1024.0, tl.float32)
    tmp12 = (tmp10 / tmp11)
    tmp13 = tl.full([1, 1], 1e-05, tl.float32)
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp5 * tmp15
    tmp17 = tmp16.to(tl.float32)
    tmp19 = tmp17 * tmp18
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp19, xmask)
''', device_str='cuda')

def partition_0(args):
    arg0_1, arg1_1, arg3_1, arg2_1, arg6_1, arg7_1, arg8_1, arg9_1, arg4_1, arg5_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 1, 4, 64), (256, 4, 1, 4))
    assert_size_stride(arg1_1, (1024, 64), (64, 1))
    assert_size_stride(arg3_1, (1, 1, 1, 1024), (1024, 1024, 1024, 1))
    assert_size_stride(arg2_1, (1024, ), (1, ))
    assert_size_stride(arg6_1, (1024, ), (1, ))
    assert_size_stride(arg7_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg8_1, (256, 1024), (1024, 1))
    assert_size_stride(arg9_1, (256, 1024), (1024, 1))
    assert_size_stride(arg4_1, (32768, 128), (128, 1))
    assert_size_stride(arg5_1, (32768, 128), (128, 1))
    assert_size_stride(arg10_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg11_1, (1024, ), (1, ))
    assert_size_stride(arg12_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg13_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg14_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg15_1, (1024, ), (1, ))
    assert_size_stride(arg16_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg17_1, (256, 1024), (1024, 1))
    assert_size_stride(arg18_1, (256, 1024), (1024, 1))
    assert_size_stride(arg19_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg20_1, (1024, ), (1, ))
    assert_size_stride(arg21_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg22_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg23_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg24_1, (1024, ), (1, ))
    assert_size_stride(arg25_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg26_1, (256, 1024), (1024, 1))
    assert_size_stride(arg27_1, (256, 1024), (1024, 1))
    assert_size_stride(arg28_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg29_1, (1024, ), (1, ))
    assert_size_stride(arg30_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg31_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg32_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg33_1, (1024, ), (1, ))
    assert_size_stride(arg34_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg35_1, (256, 1024), (1024, 1))
    assert_size_stride(arg36_1, (256, 1024), (1024, 1))
    assert_size_stride(arg37_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg38_1, (1024, ), (1, ))
    assert_size_stride(arg39_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg40_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg41_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg42_1, (1024, ), (1, ))
    assert_size_stride(arg43_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg44_1, (256, 1024), (1024, 1))
    assert_size_stride(arg45_1, (256, 1024), (1024, 1))
    assert_size_stride(arg46_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg47_1, (1024, ), (1, ))
    assert_size_stride(arg48_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg49_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg50_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg51_1, (1024, ), (1, ))
    assert_size_stride(arg52_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg53_1, (256, 1024), (1024, 1))
    assert_size_stride(arg54_1, (256, 1024), (1024, 1))
    assert_size_stride(arg55_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg56_1, (1024, ), (1, ))
    assert_size_stride(arg57_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg58_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg59_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg60_1, (1024, ), (1, ))
    assert_size_stride(arg61_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg62_1, (256, 1024), (1024, 1))
    assert_size_stride(arg63_1, (256, 1024), (1024, 1))
    assert_size_stride(arg64_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg65_1, (1024, ), (1, ))
    assert_size_stride(arg66_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg67_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg68_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg69_1, (1024, ), (1, ))
    assert_size_stride(arg70_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg71_1, (256, 1024), (1024, 1))
    assert_size_stride(arg72_1, (256, 1024), (1024, 1))
    assert_size_stride(arg73_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg74_1, (1024, ), (1, ))
    assert_size_stride(arg75_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg76_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg77_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg78_1, (1024, ), (1, ))
    assert_size_stride(arg79_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg80_1, (256, 1024), (1024, 1))
    assert_size_stride(arg81_1, (256, 1024), (1024, 1))
    assert_size_stride(arg82_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg83_1, (1024, ), (1, ))
    assert_size_stride(arg84_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg85_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg86_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg87_1, (1024, ), (1, ))
    assert_size_stride(arg88_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg89_1, (256, 1024), (1024, 1))
    assert_size_stride(arg90_1, (256, 1024), (1024, 1))
    assert_size_stride(arg91_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg92_1, (1024, ), (1, ))
    assert_size_stride(arg93_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg94_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg95_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg96_1, (1024, ), (1, ))
    assert_size_stride(arg97_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg98_1, (256, 1024), (1024, 1))
    assert_size_stride(arg99_1, (256, 1024), (1024, 1))
    assert_size_stride(arg100_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg101_1, (1024, ), (1, ))
    assert_size_stride(arg102_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg103_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg104_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg105_1, (1024, ), (1, ))
    assert_size_stride(arg106_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg107_1, (256, 1024), (1024, 1))
    assert_size_stride(arg108_1, (256, 1024), (1024, 1))
    assert_size_stride(arg109_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg110_1, (1024, ), (1, ))
    assert_size_stride(arg111_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg112_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg113_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg114_1, (1024, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((4, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [x], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(arg0_1, (4, 64), (1, 4), 0), reinterpret_tensor(arg1_1, (64, 1024), (1, 64), 0), out=buf0)
        del arg0_1
        del arg1_1
        buf2 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten.expand, aten._unsafe_view, aten.add, aten.cat, aten.view, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_0.run(arg3_1, buf0, arg2_1, arg6_1, buf2, 5, 1024, stream=stream0)
        del arg6_1
        buf3 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg7_1, (1024, 2048), (1, 1024), 0), out=buf3)
        del arg7_1
        buf4 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg8_1, (1024, 256), (1, 1024), 0), out=buf4)
        del arg8_1
        buf5 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg9_1, (1024, 256), (1, 1024), 0), out=buf5)
        del arg9_1
        buf6 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf3, arg4_1, arg5_1, buf6, 10240, stream=stream0)
        del buf3
        buf7 = empty_strided_cuda((1, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf4, arg4_1, arg5_1, buf7, 1280, stream=stream0)
        buf8 = reinterpret_tensor(buf4, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf4  # reuse
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf5, buf8, 1280, stream=stream0)
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf9 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf6, buf7, buf8, scale=0.08838834764831843)
        buf10 = buf9[0]
        assert_size_stride(buf10, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf10, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf9
        buf15 = reinterpret_tensor(buf6, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf6  # reuse
        # Topologically Sorted Source Nodes: [transpose_3, attn_output_1], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf10, buf15, 10240, stream=stream0)
        del buf10
        buf16 = reinterpret_tensor(buf2, (5, 1024), (1024, 1), 0); del buf2  # reuse
        # Topologically Sorted Source Nodes: [transpose_3, attn_output_1, attn_output_2, attn_output_3], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf15, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg10_1, (2048, 1024), (1, 2048), 0), out=buf16)
        del arg10_1
        del buf15
        buf18 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, to_6, pow_2, variance_1, add_4, rsqrt_1, mul_6, hidden_1, hidden_states_2], Original ATen: [aten.expand, aten._unsafe_view, aten.add, aten.cat, aten.view, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5.run(arg3_1, buf0, arg2_1, buf16, arg11_1, buf18, 5, 1024, stream=stream0)
        del arg11_1
        buf19 = empty_strided_cuda((5, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_5], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf18, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg12_1, (1024, 4096), (1, 1024), 0), out=buf19)
        del arg12_1
        buf20 = empty_strided_cuda((5, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_6], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf18, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg13_1, (1024, 4096), (1, 1024), 0), out=buf20)
        del arg13_1
        buf21 = reinterpret_tensor(buf19, (1, 5, 4096), (20480, 4096, 1), 0); del buf19  # reuse
        # Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf21, buf20, 20480, stream=stream0)
        buf22 = reinterpret_tensor(buf18, (5, 1024), (1024, 1), 0); del buf18  # reuse
        # Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8, hidden_states_3], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf21, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg14_1, (4096, 1024), (1, 4096), 0), out=buf22)
        del arg14_1
        buf23 = reinterpret_tensor(buf16, (1, 5, 1024), (5120, 1024, 1), 0); del buf16  # reuse
        buf25 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5], Original ATen: [aten.expand, aten._unsafe_view, aten.add, aten.cat, aten.view, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7.run(buf23, arg3_1, buf0, arg2_1, buf22, arg15_1, buf25, 5, 1024, stream=stream0)
        del arg15_1
        del arg2_1
        del arg3_1
        del buf0
        buf26 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5, query_states_4], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf25, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg16_1, (1024, 2048), (1, 1024), 0), out=buf26)
        del arg16_1
        buf27 = reinterpret_tensor(buf8, (5, 256), (256, 1), 0); del buf8  # reuse
        # Topologically Sorted Source Nodes: [key_states_4], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf25, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg17_1, (1024, 256), (1, 1024), 0), out=buf27)
        del arg17_1
        buf28 = reinterpret_tensor(buf7, (5, 256), (256, 1), 0); del buf7  # reuse
        # Topologically Sorted Source Nodes: [value_states_3], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf25, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg18_1, (1024, 256), (1, 1024), 0), out=buf28)
        del arg18_1
        buf29 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf26, arg4_1, arg5_1, buf29, 10240, stream=stream0)
        del buf26
        buf30 = reinterpret_tensor(buf5, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf5  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf27, arg4_1, arg5_1, buf30, 1280, stream=stream0)
        buf31 = reinterpret_tensor(buf27, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf27  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf28, buf31, 1280, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf32 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf29, buf30, buf31, scale=0.08838834764831843)
        buf33 = buf32[0]
        assert_size_stride(buf33, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf33, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf32
        buf38 = reinterpret_tensor(buf29, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf29  # reuse
        # Topologically Sorted Source Nodes: [transpose_7, attn_output_5], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf33, buf38, 10240, stream=stream0)
        del buf33
        buf39 = reinterpret_tensor(buf25, (5, 1024), (1024, 1), 0); del buf25  # reuse
        # Topologically Sorted Source Nodes: [transpose_7, attn_output_5, attn_output_6, attn_output_7], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf38, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg19_1, (2048, 1024), (1, 2048), 0), out=buf39)
        del arg19_1
        del buf38
        buf41 = reinterpret_tensor(buf22, (1, 5, 1024), (5120, 1024, 1), 0); del buf22  # reuse
        # Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf23, buf39, arg20_1, buf41, 5, 1024, stream=stream0)
        del arg20_1
        buf42 = reinterpret_tensor(buf21, (5, 4096), (4096, 1), 0); del buf21  # reuse
        # Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7, linear_12], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf41, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg21_1, (1024, 4096), (1, 1024), 0), out=buf42)
        del arg21_1
        buf43 = buf20; del buf20  # reuse
        # Topologically Sorted Source Nodes: [linear_13], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf41, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg22_1, (1024, 4096), (1, 1024), 0), out=buf43)
        del arg22_1
        buf44 = reinterpret_tensor(buf42, (1, 5, 4096), (20480, 4096, 1), 0); del buf42  # reuse
        # Topologically Sorted Source Nodes: [linear_12, silu_1, linear_13, mul_17], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf44, buf43, 20480, stream=stream0)
        del buf43
        buf45 = reinterpret_tensor(buf41, (5, 1024), (1024, 1), 0); del buf41  # reuse
        # Topologically Sorted Source Nodes: [linear_12, silu_1, linear_13, mul_17, hidden_states_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf44, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg23_1, (4096, 1024), (1, 4096), 0), out=buf45)
        del arg23_1
        buf47 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, hidden_states_8, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf23, buf39, buf45, arg24_1, buf47, 5, 1024, stream=stream0)
        del arg24_1
        buf48 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_8], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg25_1, (1024, 2048), (1, 1024), 0), out=buf48)
        del arg25_1
        buf49 = reinterpret_tensor(buf31, (5, 256), (256, 1), 0); del buf31  # reuse
        # Topologically Sorted Source Nodes: [key_states_8], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg26_1, (1024, 256), (1, 1024), 0), out=buf49)
        del arg26_1
        buf50 = reinterpret_tensor(buf30, (5, 256), (256, 1), 0); del buf30  # reuse
        # Topologically Sorted Source Nodes: [value_states_6], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg27_1, (1024, 256), (1, 1024), 0), out=buf50)
        del arg27_1
        buf51 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf48, arg4_1, arg5_1, buf51, 10240, stream=stream0)
        del buf48
        buf52 = reinterpret_tensor(buf28, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf28  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf49, arg4_1, arg5_1, buf52, 1280, stream=stream0)
        buf53 = reinterpret_tensor(buf49, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf49  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf50, buf53, 1280, stream=stream0)
        del buf50
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf54 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf51, buf52, buf53, scale=0.08838834764831843)
        del buf52
        del buf53
        buf55 = buf54[0]
        assert_size_stride(buf55, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf55, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf54
        buf60 = reinterpret_tensor(buf51, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf51  # reuse
        # Topologically Sorted Source Nodes: [transpose_11, attn_output_9], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf55, buf60, 10240, stream=stream0)
        del buf55
        buf61 = reinterpret_tensor(buf47, (5, 1024), (1024, 1), 0); del buf47  # reuse
        # Topologically Sorted Source Nodes: [transpose_11, attn_output_9, attn_output_10, attn_output_11], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf60, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg28_1, (2048, 1024), (1, 2048), 0), out=buf61)
        del arg28_1
        del buf60
        buf63 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, to_22, pow_6, variance_5, add_16, rsqrt_5, mul_24, hidden_5, hidden_states_12], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf23, buf39, buf45, buf61, arg29_1, buf63, 5, 1024, stream=stream0)
        del arg29_1
        buf64 = reinterpret_tensor(buf44, (5, 4096), (4096, 1), 0); del buf44  # reuse
        # Topologically Sorted Source Nodes: [linear_19], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf63, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg30_1, (1024, 4096), (1, 1024), 0), out=buf64)
        del arg30_1
        buf65 = empty_strided_cuda((5, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf63, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg31_1, (1024, 4096), (1, 1024), 0), out=buf65)
        del arg31_1
        del buf63
        buf66 = reinterpret_tensor(buf64, (1, 5, 4096), (20480, 4096, 1), 0); del buf64  # reuse
        # Topologically Sorted Source Nodes: [linear_19, silu_2, linear_20, mul_26], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf66, buf65, 20480, stream=stream0)
        buf67 = empty_strided_cuda((5, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_19, silu_2, linear_20, mul_26, hidden_states_13], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf66, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg32_1, (4096, 1024), (1, 4096), 0), out=buf67)
        del arg32_1
        buf68 = buf23; del buf23  # reuse
        buf70 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_7, hidden_states_6, hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, to_24, pow_7, variance_6, add_18, rsqrt_6, mul_27, hidden_6, hidden_states_15], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf68, buf39, buf45, buf61, buf67, arg33_1, buf70, 5, 1024, stream=stream0)
        del arg33_1
        del buf39
        del buf45
        del buf61
        buf71 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_24, pow_7, variance_6, add_18, rsqrt_6, mul_27, hidden_6, hidden_states_15, query_states_12], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf70, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg34_1, (1024, 2048), (1, 1024), 0), out=buf71)
        del arg34_1
        buf72 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf70, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg35_1, (1024, 256), (1, 1024), 0), out=buf72)
        del arg35_1
        buf73 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_9], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf70, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg36_1, (1024, 256), (1, 1024), 0), out=buf73)
        del arg36_1
        buf74 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf71, arg4_1, arg5_1, buf74, 10240, stream=stream0)
        del buf71
        buf75 = empty_strided_cuda((1, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf72, arg4_1, arg5_1, buf75, 1280, stream=stream0)
        buf76 = reinterpret_tensor(buf72, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf72  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf73, buf76, 1280, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf77 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf74, buf75, buf76, scale=0.08838834764831843)
        buf78 = buf77[0]
        assert_size_stride(buf78, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf78, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf77
        buf83 = reinterpret_tensor(buf74, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf74  # reuse
        # Topologically Sorted Source Nodes: [transpose_15, attn_output_13], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf78, buf83, 10240, stream=stream0)
        del buf78
        buf84 = reinterpret_tensor(buf70, (5, 1024), (1024, 1), 0); del buf70  # reuse
        # Topologically Sorted Source Nodes: [transpose_15, attn_output_13, attn_output_14, attn_output_15], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf83, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg37_1, (2048, 1024), (1, 2048), 0), out=buf84)
        del arg37_1
        del buf83
        buf86 = reinterpret_tensor(buf67, (1, 5, 1024), (5120, 1024, 1), 0); del buf67  # reuse
        # Topologically Sorted Source Nodes: [attn_output_15, hidden_states_16, to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf68, buf84, arg38_1, buf86, 5, 1024, stream=stream0)
        del arg38_1
        buf87 = reinterpret_tensor(buf66, (5, 4096), (4096, 1), 0); del buf66  # reuse
        # Topologically Sorted Source Nodes: [attn_output_15, hidden_states_16, to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17, linear_26], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf86, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg39_1, (1024, 4096), (1, 1024), 0), out=buf87)
        del arg39_1
        buf88 = buf65; del buf65  # reuse
        # Topologically Sorted Source Nodes: [linear_27], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf86, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg40_1, (1024, 4096), (1, 1024), 0), out=buf88)
        del arg40_1
        buf89 = reinterpret_tensor(buf87, (1, 5, 4096), (20480, 4096, 1), 0); del buf87  # reuse
        # Topologically Sorted Source Nodes: [linear_26, silu_3, linear_27, mul_35], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf89, buf88, 20480, stream=stream0)
        del buf88
        buf90 = reinterpret_tensor(buf86, (5, 1024), (1024, 1), 0); del buf86  # reuse
        # Topologically Sorted Source Nodes: [linear_26, silu_3, linear_27, mul_35, hidden_states_18], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf89, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg41_1, (4096, 1024), (1, 4096), 0), out=buf90)
        del arg41_1
        buf92 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_15, hidden_states_16, hidden_states_18, hidden_states_19, to_32, pow_9, variance_8, add_24, rsqrt_8, mul_36, hidden_8, hidden_states_20], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf68, buf84, buf90, arg42_1, buf92, 5, 1024, stream=stream0)
        del arg42_1
        buf93 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_16], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg43_1, (1024, 2048), (1, 1024), 0), out=buf93)
        del arg43_1
        buf94 = reinterpret_tensor(buf76, (5, 256), (256, 1), 0); del buf76  # reuse
        # Topologically Sorted Source Nodes: [key_states_16], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg44_1, (1024, 256), (1, 1024), 0), out=buf94)
        del arg44_1
        buf95 = reinterpret_tensor(buf75, (5, 256), (256, 1), 0); del buf75  # reuse
        # Topologically Sorted Source Nodes: [value_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg45_1, (1024, 256), (1, 1024), 0), out=buf95)
        del arg45_1
        buf96 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf93, arg4_1, arg5_1, buf96, 10240, stream=stream0)
        del buf93
        buf97 = reinterpret_tensor(buf73, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf73  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf94, arg4_1, arg5_1, buf97, 1280, stream=stream0)
        buf98 = reinterpret_tensor(buf94, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf94  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf95, buf98, 1280, stream=stream0)
        del buf95
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf99 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf96, buf97, buf98, scale=0.08838834764831843)
        del buf97
        del buf98
        buf100 = buf99[0]
        assert_size_stride(buf100, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf100, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf99
        buf105 = reinterpret_tensor(buf96, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf96  # reuse
        # Topologically Sorted Source Nodes: [transpose_19, attn_output_17], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf100, buf105, 10240, stream=stream0)
        del buf100
        buf106 = reinterpret_tensor(buf92, (5, 1024), (1024, 1), 0); del buf92  # reuse
        # Topologically Sorted Source Nodes: [transpose_19, attn_output_17, attn_output_18, attn_output_19], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf105, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg46_1, (2048, 1024), (1, 2048), 0), out=buf106)
        del arg46_1
        del buf105
        buf108 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_15, hidden_states_16, hidden_states_18, hidden_states_19, attn_output_19, hidden_states_21, to_38, pow_10, variance_9, add_28, rsqrt_9, mul_42, hidden_9, hidden_states_22], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf68, buf84, buf90, buf106, arg47_1, buf108, 5, 1024, stream=stream0)
        del arg47_1
        buf109 = reinterpret_tensor(buf89, (5, 4096), (4096, 1), 0); del buf89  # reuse
        # Topologically Sorted Source Nodes: [linear_33], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf108, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg48_1, (1024, 4096), (1, 1024), 0), out=buf109)
        del arg48_1
        buf110 = empty_strided_cuda((5, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_34], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf108, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg49_1, (1024, 4096), (1, 1024), 0), out=buf110)
        del arg49_1
        del buf108
        buf111 = reinterpret_tensor(buf109, (1, 5, 4096), (20480, 4096, 1), 0); del buf109  # reuse
        # Topologically Sorted Source Nodes: [linear_33, silu_4, linear_34, mul_44], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf111, buf110, 20480, stream=stream0)
        buf112 = empty_strided_cuda((5, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_33, silu_4, linear_34, mul_44, hidden_states_23], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf111, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg50_1, (4096, 1024), (1, 4096), 0), out=buf112)
        del arg50_1
        buf113 = buf68; del buf68  # reuse
        buf115 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_15, hidden_states_16, hidden_states_18, hidden_states_19, attn_output_19, hidden_states_21, hidden_states_23, hidden_states_24, to_40, pow_11, variance_10, add_30, rsqrt_10, mul_45, hidden_10, hidden_states_25], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf113, buf84, buf90, buf106, buf112, arg51_1, buf115, 5, 1024, stream=stream0)
        del arg51_1
        del buf106
        del buf112
        del buf84
        buf116 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_40, pow_11, variance_10, add_30, rsqrt_10, mul_45, hidden_10, hidden_states_25, query_states_20], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf115, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg52_1, (1024, 2048), (1, 1024), 0), out=buf116)
        del arg52_1
        buf117 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf115, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg53_1, (1024, 256), (1, 1024), 0), out=buf117)
        del arg53_1
        buf118 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_15], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf115, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg54_1, (1024, 256), (1, 1024), 0), out=buf118)
        del arg54_1
        buf119 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf116, arg4_1, arg5_1, buf119, 10240, stream=stream0)
        del buf116
        buf120 = empty_strided_cuda((1, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf117, arg4_1, arg5_1, buf120, 1280, stream=stream0)
        buf121 = reinterpret_tensor(buf117, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf117  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf118, buf121, 1280, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf122 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf119, buf120, buf121, scale=0.08838834764831843)
        buf123 = buf122[0]
        assert_size_stride(buf123, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf123, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf122
        buf128 = reinterpret_tensor(buf119, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf119  # reuse
        # Topologically Sorted Source Nodes: [transpose_23, attn_output_21], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf123, buf128, 10240, stream=stream0)
        del buf123
        buf129 = reinterpret_tensor(buf115, (5, 1024), (1024, 1), 0); del buf115  # reuse
        # Topologically Sorted Source Nodes: [transpose_23, attn_output_21, attn_output_22, attn_output_23], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf128, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg55_1, (2048, 1024), (1, 2048), 0), out=buf129)
        del arg55_1
        del buf128
        buf131 = reinterpret_tensor(buf90, (1, 5, 1024), (5120, 1024, 1), 0); del buf90  # reuse
        # Topologically Sorted Source Nodes: [attn_output_23, hidden_states_26, to_46, pow_12, variance_11, add_34, rsqrt_11, mul_51, hidden_11, hidden_states_27], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf113, buf129, arg56_1, buf131, 5, 1024, stream=stream0)
        del arg56_1
        buf132 = reinterpret_tensor(buf111, (5, 4096), (4096, 1), 0); del buf111  # reuse
        # Topologically Sorted Source Nodes: [attn_output_23, hidden_states_26, to_46, pow_12, variance_11, add_34, rsqrt_11, mul_51, hidden_11, hidden_states_27, linear_40], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf131, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg57_1, (1024, 4096), (1, 1024), 0), out=buf132)
        del arg57_1
        buf133 = buf110; del buf110  # reuse
        # Topologically Sorted Source Nodes: [linear_41], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf131, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg58_1, (1024, 4096), (1, 1024), 0), out=buf133)
        del arg58_1
        buf134 = reinterpret_tensor(buf132, (1, 5, 4096), (20480, 4096, 1), 0); del buf132  # reuse
        # Topologically Sorted Source Nodes: [linear_40, silu_5, linear_41, mul_53], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf134, buf133, 20480, stream=stream0)
        del buf133
        buf135 = reinterpret_tensor(buf131, (5, 1024), (1024, 1), 0); del buf131  # reuse
        # Topologically Sorted Source Nodes: [linear_40, silu_5, linear_41, mul_53, hidden_states_28], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf134, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg59_1, (4096, 1024), (1, 4096), 0), out=buf135)
        del arg59_1
        buf137 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_23, hidden_states_26, hidden_states_28, hidden_states_29, to_48, pow_13, variance_12, add_36, rsqrt_12, mul_54, hidden_12, hidden_states_30], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf113, buf129, buf135, arg60_1, buf137, 5, 1024, stream=stream0)
        del arg60_1
        buf138 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg61_1, (1024, 2048), (1, 1024), 0), out=buf138)
        del arg61_1
        buf139 = reinterpret_tensor(buf121, (5, 256), (256, 1), 0); del buf121  # reuse
        # Topologically Sorted Source Nodes: [key_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg62_1, (1024, 256), (1, 1024), 0), out=buf139)
        del arg62_1
        buf140 = reinterpret_tensor(buf120, (5, 256), (256, 1), 0); del buf120  # reuse
        # Topologically Sorted Source Nodes: [value_states_18], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg63_1, (1024, 256), (1, 1024), 0), out=buf140)
        del arg63_1
        buf141 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf138, arg4_1, arg5_1, buf141, 10240, stream=stream0)
        del buf138
        buf142 = reinterpret_tensor(buf118, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf118  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf139, arg4_1, arg5_1, buf142, 1280, stream=stream0)
        buf143 = reinterpret_tensor(buf139, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf139  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf140, buf143, 1280, stream=stream0)
        del buf140
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf144 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf141, buf142, buf143, scale=0.08838834764831843)
        del buf142
        del buf143
        buf145 = buf144[0]
        assert_size_stride(buf145, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf145, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf144
        buf150 = reinterpret_tensor(buf141, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf141  # reuse
        # Topologically Sorted Source Nodes: [transpose_27, attn_output_25], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf145, buf150, 10240, stream=stream0)
        del buf145
        buf151 = reinterpret_tensor(buf137, (5, 1024), (1024, 1), 0); del buf137  # reuse
        # Topologically Sorted Source Nodes: [transpose_27, attn_output_25, attn_output_26, attn_output_27], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf150, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg64_1, (2048, 1024), (1, 2048), 0), out=buf151)
        del arg64_1
        del buf150
        buf153 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_23, hidden_states_26, hidden_states_28, hidden_states_29, attn_output_27, hidden_states_31, to_54, pow_14, variance_13, add_40, rsqrt_13, mul_60, hidden_13, hidden_states_32], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf113, buf129, buf135, buf151, arg65_1, buf153, 5, 1024, stream=stream0)
        del arg65_1
        buf154 = reinterpret_tensor(buf134, (5, 4096), (4096, 1), 0); del buf134  # reuse
        # Topologically Sorted Source Nodes: [linear_47], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf153, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg66_1, (1024, 4096), (1, 1024), 0), out=buf154)
        del arg66_1
        buf155 = empty_strided_cuda((5, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_48], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf153, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg67_1, (1024, 4096), (1, 1024), 0), out=buf155)
        del arg67_1
        del buf153
        buf156 = reinterpret_tensor(buf154, (1, 5, 4096), (20480, 4096, 1), 0); del buf154  # reuse
        # Topologically Sorted Source Nodes: [linear_47, silu_6, linear_48, mul_62], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf156, buf155, 20480, stream=stream0)
        buf157 = empty_strided_cuda((5, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_47, silu_6, linear_48, mul_62, hidden_states_33], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf156, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg68_1, (4096, 1024), (1, 4096), 0), out=buf157)
        del arg68_1
        buf158 = buf113; del buf113  # reuse
        buf160 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_23, hidden_states_26, hidden_states_28, hidden_states_29, attn_output_27, hidden_states_31, hidden_states_33, hidden_states_34, to_56, pow_15, variance_14, add_42, rsqrt_14, mul_63, hidden_14, hidden_states_35], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf158, buf129, buf135, buf151, buf157, arg69_1, buf160, 5, 1024, stream=stream0)
        del arg69_1
        del buf129
        del buf135
        del buf151
        buf161 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_56, pow_15, variance_14, add_42, rsqrt_14, mul_63, hidden_14, hidden_states_35, query_states_28], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf160, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg70_1, (1024, 2048), (1, 1024), 0), out=buf161)
        del arg70_1
        buf162 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_28], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf160, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg71_1, (1024, 256), (1, 1024), 0), out=buf162)
        del arg71_1
        buf163 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_21], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf160, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg72_1, (1024, 256), (1, 1024), 0), out=buf163)
        del arg72_1
        buf164 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf161, arg4_1, arg5_1, buf164, 10240, stream=stream0)
        del buf161
        buf165 = empty_strided_cuda((1, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf162, arg4_1, arg5_1, buf165, 1280, stream=stream0)
        buf166 = reinterpret_tensor(buf162, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf162  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf163, buf166, 1280, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf167 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf164, buf165, buf166, scale=0.08838834764831843)
        buf168 = buf167[0]
        assert_size_stride(buf168, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf168, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf167
        buf173 = reinterpret_tensor(buf164, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf164  # reuse
        # Topologically Sorted Source Nodes: [transpose_31, attn_output_29], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf168, buf173, 10240, stream=stream0)
        del buf168
        buf174 = reinterpret_tensor(buf160, (5, 1024), (1024, 1), 0); del buf160  # reuse
        # Topologically Sorted Source Nodes: [transpose_31, attn_output_29, attn_output_30, attn_output_31], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf173, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg73_1, (2048, 1024), (1, 2048), 0), out=buf174)
        del arg73_1
        del buf173
        buf176 = reinterpret_tensor(buf157, (1, 5, 1024), (5120, 1024, 1), 0); del buf157  # reuse
        # Topologically Sorted Source Nodes: [attn_output_31, hidden_states_36, to_62, pow_16, variance_15, add_46, rsqrt_15, mul_69, hidden_15, hidden_states_37], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf158, buf174, arg74_1, buf176, 5, 1024, stream=stream0)
        del arg74_1
        buf177 = reinterpret_tensor(buf156, (5, 4096), (4096, 1), 0); del buf156  # reuse
        # Topologically Sorted Source Nodes: [attn_output_31, hidden_states_36, to_62, pow_16, variance_15, add_46, rsqrt_15, mul_69, hidden_15, hidden_states_37, linear_54], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf176, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg75_1, (1024, 4096), (1, 1024), 0), out=buf177)
        del arg75_1
        buf178 = buf155; del buf155  # reuse
        # Topologically Sorted Source Nodes: [linear_55], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf176, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg76_1, (1024, 4096), (1, 1024), 0), out=buf178)
        del arg76_1
        buf179 = reinterpret_tensor(buf177, (1, 5, 4096), (20480, 4096, 1), 0); del buf177  # reuse
        # Topologically Sorted Source Nodes: [linear_54, silu_7, linear_55, mul_71], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf179, buf178, 20480, stream=stream0)
        del buf178
        buf180 = reinterpret_tensor(buf176, (5, 1024), (1024, 1), 0); del buf176  # reuse
        # Topologically Sorted Source Nodes: [linear_54, silu_7, linear_55, mul_71, hidden_states_38], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf179, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg77_1, (4096, 1024), (1, 4096), 0), out=buf180)
        del arg77_1
        buf182 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_31, hidden_states_36, hidden_states_38, hidden_states_39, to_64, pow_17, variance_16, add_48, rsqrt_16, mul_72, hidden_16, hidden_states_40], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf158, buf174, buf180, arg78_1, buf182, 5, 1024, stream=stream0)
        del arg78_1
        buf183 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_32], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg79_1, (1024, 2048), (1, 1024), 0), out=buf183)
        del arg79_1
        buf184 = reinterpret_tensor(buf166, (5, 256), (256, 1), 0); del buf166  # reuse
        # Topologically Sorted Source Nodes: [key_states_32], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg80_1, (1024, 256), (1, 1024), 0), out=buf184)
        del arg80_1
        buf185 = reinterpret_tensor(buf165, (5, 256), (256, 1), 0); del buf165  # reuse
        # Topologically Sorted Source Nodes: [value_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg81_1, (1024, 256), (1, 1024), 0), out=buf185)
        del arg81_1
        buf186 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf183, arg4_1, arg5_1, buf186, 10240, stream=stream0)
        del buf183
        buf187 = reinterpret_tensor(buf163, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf163  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf184, arg4_1, arg5_1, buf187, 1280, stream=stream0)
        buf188 = reinterpret_tensor(buf184, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf184  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf185, buf188, 1280, stream=stream0)
        del buf185
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf189 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf186, buf187, buf188, scale=0.08838834764831843)
        del buf187
        del buf188
        buf190 = buf189[0]
        assert_size_stride(buf190, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf190, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf189
        buf195 = reinterpret_tensor(buf186, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf186  # reuse
        # Topologically Sorted Source Nodes: [transpose_35, attn_output_33], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf190, buf195, 10240, stream=stream0)
        del buf190
        buf196 = reinterpret_tensor(buf182, (5, 1024), (1024, 1), 0); del buf182  # reuse
        # Topologically Sorted Source Nodes: [transpose_35, attn_output_33, attn_output_34, attn_output_35], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf195, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg82_1, (2048, 1024), (1, 2048), 0), out=buf196)
        del arg82_1
        del buf195
        buf198 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_31, hidden_states_36, hidden_states_38, hidden_states_39, attn_output_35, hidden_states_41, to_70, pow_18, variance_17, add_52, rsqrt_17, mul_78, hidden_17, hidden_states_42], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf158, buf174, buf180, buf196, arg83_1, buf198, 5, 1024, stream=stream0)
        del arg83_1
        buf199 = reinterpret_tensor(buf179, (5, 4096), (4096, 1), 0); del buf179  # reuse
        # Topologically Sorted Source Nodes: [linear_61], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf198, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg84_1, (1024, 4096), (1, 1024), 0), out=buf199)
        del arg84_1
        buf200 = empty_strided_cuda((5, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_62], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf198, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg85_1, (1024, 4096), (1, 1024), 0), out=buf200)
        del arg85_1
        del buf198
        buf201 = reinterpret_tensor(buf199, (1, 5, 4096), (20480, 4096, 1), 0); del buf199  # reuse
        # Topologically Sorted Source Nodes: [linear_61, silu_8, linear_62, mul_80], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf201, buf200, 20480, stream=stream0)
        buf202 = empty_strided_cuda((5, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_61, silu_8, linear_62, mul_80, hidden_states_43], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf201, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg86_1, (4096, 1024), (1, 4096), 0), out=buf202)
        del arg86_1
        buf203 = buf158; del buf158  # reuse
        buf205 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_31, hidden_states_36, hidden_states_38, hidden_states_39, attn_output_35, hidden_states_41, hidden_states_43, hidden_states_44, to_72, pow_19, variance_18, add_54, rsqrt_18, mul_81, hidden_18, hidden_states_45], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf203, buf174, buf180, buf196, buf202, arg87_1, buf205, 5, 1024, stream=stream0)
        del arg87_1
        del buf174
        del buf180
        del buf196
        buf206 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_72, pow_19, variance_18, add_54, rsqrt_18, mul_81, hidden_18, hidden_states_45, query_states_36], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf205, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg88_1, (1024, 2048), (1, 1024), 0), out=buf206)
        del arg88_1
        buf207 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_36], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf205, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg89_1, (1024, 256), (1, 1024), 0), out=buf207)
        del arg89_1
        buf208 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_27], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf205, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg90_1, (1024, 256), (1, 1024), 0), out=buf208)
        del arg90_1
        buf209 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf206, arg4_1, arg5_1, buf209, 10240, stream=stream0)
        del buf206
        buf210 = empty_strided_cuda((1, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf207, arg4_1, arg5_1, buf210, 1280, stream=stream0)
        buf211 = reinterpret_tensor(buf207, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf207  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf208, buf211, 1280, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf212 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf209, buf210, buf211, scale=0.08838834764831843)
        buf213 = buf212[0]
        assert_size_stride(buf213, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf213, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf212
        buf218 = reinterpret_tensor(buf209, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf209  # reuse
        # Topologically Sorted Source Nodes: [transpose_39, attn_output_37], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf213, buf218, 10240, stream=stream0)
        del buf213
        buf219 = reinterpret_tensor(buf205, (5, 1024), (1024, 1), 0); del buf205  # reuse
        # Topologically Sorted Source Nodes: [transpose_39, attn_output_37, attn_output_38, attn_output_39], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf218, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg91_1, (2048, 1024), (1, 2048), 0), out=buf219)
        del arg91_1
        del buf218
        buf221 = reinterpret_tensor(buf202, (1, 5, 1024), (5120, 1024, 1), 0); del buf202  # reuse
        # Topologically Sorted Source Nodes: [attn_output_39, hidden_states_46, to_78, pow_20, variance_19, add_58, rsqrt_19, mul_87, hidden_19, hidden_states_47], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf203, buf219, arg92_1, buf221, 5, 1024, stream=stream0)
        del arg92_1
        buf222 = reinterpret_tensor(buf201, (5, 4096), (4096, 1), 0); del buf201  # reuse
        # Topologically Sorted Source Nodes: [attn_output_39, hidden_states_46, to_78, pow_20, variance_19, add_58, rsqrt_19, mul_87, hidden_19, hidden_states_47, linear_68], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf221, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg93_1, (1024, 4096), (1, 1024), 0), out=buf222)
        del arg93_1
        buf223 = buf200; del buf200  # reuse
        # Topologically Sorted Source Nodes: [linear_69], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf221, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg94_1, (1024, 4096), (1, 1024), 0), out=buf223)
        del arg94_1
        buf224 = reinterpret_tensor(buf222, (1, 5, 4096), (20480, 4096, 1), 0); del buf222  # reuse
        # Topologically Sorted Source Nodes: [linear_68, silu_9, linear_69, mul_89], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf224, buf223, 20480, stream=stream0)
        del buf223
        buf225 = reinterpret_tensor(buf221, (5, 1024), (1024, 1), 0); del buf221  # reuse
        # Topologically Sorted Source Nodes: [linear_68, silu_9, linear_69, mul_89, hidden_states_48], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf224, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg95_1, (4096, 1024), (1, 4096), 0), out=buf225)
        del arg95_1
        buf227 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_39, hidden_states_46, hidden_states_48, hidden_states_49, to_80, pow_21, variance_20, add_60, rsqrt_20, mul_90, hidden_20, hidden_states_50], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf203, buf219, buf225, arg96_1, buf227, 5, 1024, stream=stream0)
        del arg96_1
        buf228 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_40], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg97_1, (1024, 2048), (1, 1024), 0), out=buf228)
        del arg97_1
        buf229 = reinterpret_tensor(buf211, (5, 256), (256, 1), 0); del buf211  # reuse
        # Topologically Sorted Source Nodes: [key_states_40], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg98_1, (1024, 256), (1, 1024), 0), out=buf229)
        del arg98_1
        buf230 = reinterpret_tensor(buf210, (5, 256), (256, 1), 0); del buf210  # reuse
        # Topologically Sorted Source Nodes: [value_states_30], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg99_1, (1024, 256), (1, 1024), 0), out=buf230)
        del arg99_1
        buf231 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf228, arg4_1, arg5_1, buf231, 10240, stream=stream0)
        del buf228
        buf232 = reinterpret_tensor(buf208, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf208  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf229, arg4_1, arg5_1, buf232, 1280, stream=stream0)
        buf233 = reinterpret_tensor(buf229, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf229  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf230, buf233, 1280, stream=stream0)
        del buf230
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf234 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf231, buf232, buf233, scale=0.08838834764831843)
        del buf232
        del buf233
        buf235 = buf234[0]
        assert_size_stride(buf235, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf235, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf234
        buf240 = reinterpret_tensor(buf231, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf231  # reuse
        # Topologically Sorted Source Nodes: [transpose_43, attn_output_41], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf235, buf240, 10240, stream=stream0)
        del buf235
        buf241 = reinterpret_tensor(buf227, (5, 1024), (1024, 1), 0); del buf227  # reuse
        # Topologically Sorted Source Nodes: [transpose_43, attn_output_41, attn_output_42, attn_output_43], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf240, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg100_1, (2048, 1024), (1, 2048), 0), out=buf241)
        del arg100_1
        del buf240
        buf243 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_39, hidden_states_46, hidden_states_48, hidden_states_49, attn_output_43, hidden_states_51, to_86, pow_22, variance_21, add_64, rsqrt_21, mul_96, hidden_21, hidden_states_52], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf203, buf219, buf225, buf241, arg101_1, buf243, 5, 1024, stream=stream0)
        del arg101_1
        buf244 = reinterpret_tensor(buf224, (5, 4096), (4096, 1), 0); del buf224  # reuse
        # Topologically Sorted Source Nodes: [linear_75], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf243, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg102_1, (1024, 4096), (1, 1024), 0), out=buf244)
        del arg102_1
        buf245 = empty_strided_cuda((5, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_76], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf243, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg103_1, (1024, 4096), (1, 1024), 0), out=buf245)
        del arg103_1
        del buf243
        buf246 = reinterpret_tensor(buf244, (1, 5, 4096), (20480, 4096, 1), 0); del buf244  # reuse
        # Topologically Sorted Source Nodes: [linear_75, silu_10, linear_76, mul_98], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf246, buf245, 20480, stream=stream0)
        buf247 = empty_strided_cuda((5, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_75, silu_10, linear_76, mul_98, hidden_states_53], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf246, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg104_1, (4096, 1024), (1, 4096), 0), out=buf247)
        del arg104_1
        buf248 = buf203; del buf203  # reuse
        buf250 = empty_strided_cuda((1, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_39, hidden_states_46, hidden_states_48, hidden_states_49, attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, to_88, pow_23, variance_22, add_66, rsqrt_22, mul_99, hidden_22, hidden_states_55], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf248, buf219, buf225, buf241, buf247, arg105_1, buf250, 5, 1024, stream=stream0)
        del arg105_1
        del buf219
        del buf225
        del buf241
        buf251 = empty_strided_cuda((5, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_88, pow_23, variance_22, add_66, rsqrt_22, mul_99, hidden_22, hidden_states_55, query_states_44], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf250, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg106_1, (1024, 2048), (1, 1024), 0), out=buf251)
        del arg106_1
        buf252 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_44], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf250, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg107_1, (1024, 256), (1, 1024), 0), out=buf252)
        del arg107_1
        buf253 = empty_strided_cuda((5, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_33], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf250, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg108_1, (1024, 256), (1, 1024), 0), out=buf253)
        del arg108_1
        buf254 = empty_strided_cuda((1, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf251, arg4_1, arg5_1, buf254, 10240, stream=stream0)
        del buf251
        buf255 = empty_strided_cuda((1, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf252, arg4_1, arg5_1, buf255, 1280, stream=stream0)
        del arg4_1
        del arg5_1
        buf256 = reinterpret_tensor(buf252, (1, 2, 5, 128), (1280, 640, 128, 1), 0); del buf252  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf253, buf256, 1280, stream=stream0)
        del buf253
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf257 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf254, buf255, buf256, scale=0.08838834764831843)
        del buf255
        del buf256
        buf258 = buf257[0]
        assert_size_stride(buf258, (1, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf258, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf257
        buf263 = reinterpret_tensor(buf254, (1, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf254  # reuse
        # Topologically Sorted Source Nodes: [transpose_47, attn_output_45], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf258, buf263, 10240, stream=stream0)
        del buf258
        buf264 = reinterpret_tensor(buf250, (5, 1024), (1024, 1), 0); del buf250  # reuse
        # Topologically Sorted Source Nodes: [transpose_47, attn_output_45, attn_output_46, attn_output_47], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf263, (5, 2048), (2048, 1), 0), reinterpret_tensor(arg109_1, (2048, 1024), (1, 2048), 0), out=buf264)
        del arg109_1
        del buf263
        buf266 = reinterpret_tensor(buf247, (1, 5, 1024), (5120, 1024, 1), 0); del buf247  # reuse
        # Topologically Sorted Source Nodes: [attn_output_47, hidden_states_56, to_94, pow_24, variance_23, add_70, rsqrt_23, mul_105, hidden_23, hidden_states_57], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf248, buf264, arg110_1, buf266, 5, 1024, stream=stream0)
        del arg110_1
        buf267 = reinterpret_tensor(buf246, (5, 4096), (4096, 1), 0); del buf246  # reuse
        # Topologically Sorted Source Nodes: [attn_output_47, hidden_states_56, to_94, pow_24, variance_23, add_70, rsqrt_23, mul_105, hidden_23, hidden_states_57, linear_82], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf266, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg111_1, (1024, 4096), (1, 1024), 0), out=buf267)
        del arg111_1
        buf268 = buf245; del buf245  # reuse
        # Topologically Sorted Source Nodes: [linear_83], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf266, (5, 1024), (1024, 1), 0), reinterpret_tensor(arg112_1, (1024, 4096), (1, 1024), 0), out=buf268)
        del arg112_1
        buf269 = reinterpret_tensor(buf267, (1, 5, 4096), (20480, 4096, 1), 0); del buf267  # reuse
        # Topologically Sorted Source Nodes: [linear_82, silu_11, linear_83, mul_107], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf269, buf268, 20480, stream=stream0)
        del buf268
        buf270 = reinterpret_tensor(buf266, (5, 1024), (1024, 1), 0); del buf266  # reuse
        # Topologically Sorted Source Nodes: [linear_82, silu_11, linear_83, mul_107, hidden_states_58], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf269, (5, 4096), (4096, 1), 0), reinterpret_tensor(arg113_1, (4096, 1024), (1, 4096), 0), out=buf270)
        del arg113_1
        del buf269
        buf272 = buf248; del buf248  # reuse
        # Topologically Sorted Source Nodes: [attn_output_47, hidden_states_56, hidden_states_58, hidden_states_59, to_96, pow_25, variance_24, add_72, rsqrt_24, mul_108, hidden_24, hidden_states_60], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf272, buf264, buf270, arg114_1, 5, 1024, stream=stream0)
        del arg114_1
        del buf264
        del buf270
    return (buf272, )


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1 = args
        args.clear()
        partition0_args = [arg0_1, arg1_1, arg3_1, arg2_1, arg6_1, arg7_1, arg8_1, arg9_1, arg4_1, arg5_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1]
        del arg0_1, arg1_1, arg3_1, arg2_1, arg6_1, arg7_1, arg8_1, arg9_1, arg4_1, arg5_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1
        (buf272,) = self.partitions[0](partition0_args)
        del partition0_args
        return (reinterpret_tensor(buf272, (1, 1, 1024), (5120, 1024, 1), 0), )

runner = Runner(partitions=[partition_0,])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def get_args():
    from torch._dynamo.testing import rand_strided
    arg0_1 = rand_strided((1, 1, 4, 64), (256, 4, 1, 4), device='cuda:0', dtype=torch.bfloat16)
    arg1_1 = rand_strided((1024, 64), (64, 1), device='cuda:0', dtype=torch.bfloat16)
    arg2_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg3_1 = rand_strided((1, 1, 1, 1024), (1024, 1024, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg4_1 = rand_strided((32768, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg5_1 = rand_strided((32768, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg6_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg7_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg8_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg9_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg10_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg11_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg12_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg13_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg14_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg15_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg16_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg17_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg18_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg19_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg20_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg21_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg22_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg23_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg24_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg25_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg26_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg27_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg28_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg29_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg30_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg31_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg32_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg33_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg34_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg35_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg36_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg37_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg38_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg39_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg40_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg41_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg42_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg43_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg44_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg45_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg46_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg47_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg48_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg49_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg50_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg51_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg52_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg53_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg54_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg55_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg56_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg57_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg58_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg59_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg60_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg61_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg62_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg63_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg64_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg65_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg66_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg67_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg68_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg69_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg70_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg71_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg72_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg73_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg74_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg75_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg76_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg77_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg78_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg79_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg80_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg81_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg82_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg83_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg84_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg85_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg86_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg87_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg88_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg89_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg90_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg91_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg92_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg93_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg94_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg95_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg96_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg97_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg98_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg99_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg100_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg101_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg102_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg103_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg104_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg105_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg106_1 = rand_strided((2048, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg107_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg108_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg109_1 = rand_strided((1024, 2048), (2048, 1), device='cuda:0', dtype=torch.bfloat16)
    arg110_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg111_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg112_1 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
    arg113_1 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg114_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    return [arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1]


def benchmark_compiled_module(args, times=10, repeat=10):
    from torch._inductor.utils import print_performance
    fn = lambda: call(list(args))
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))
