# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Short aliases bound once at import time for use by the generated code in
# this module: operator namespaces plus helpers from torch._C._dynamo.guards
# (size/stride/alignment assertions, per-device strided allocators, and
# tensor reinterpretation).
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Shared compiler handle; each async_compile.triton(...) call below passes a
# kernel's name and full Triton source text to it.
async_compile = AsyncCompile()
# NOTE(review): resolving this attribute touches torch._C._distributed_c10d;
# presumably this requires a build with distributed support -- confirm.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
from torch._C import _cuda_getCurrentRawStream as get_raw_stream



# kernel path: /tmp/torchinductor_shingokuga/3g/c3glwnvb334zdkldm7peur4cxmsrxd43tadv46bp2tlyoh64iwmj.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten.expand, aten.view, aten.cat, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add => add
#   hidden => convert_element_type_4
#   hidden_states => mul_1
#   mul => mul
#   pow_1 => pow_1
#   rsqrt => rsqrt
#   special_tokens => expand
#   to => convert_element_type_3
#   variance => mean
#   x => view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %addmm : Tensor "bf16[40, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %buf1 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf1]
#   %arg6_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg6_1]
#   %expand : Tensor "bf16[1, 10, 1, 1024][1024, 0, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg3_1, [1, 10, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, 10, 4, 1024][40960, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, 10, 4, 1024]), kwargs = {})
#   %cat : Tensor "bf16[1, 10, 5, 1024][51200, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %view_1], 2), kwargs = {})
#   %view_2 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [10, 5, 1024]), kwargs = {})
#   %convert_element_type_3 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_2, torch.float32), kwargs = {})
#   %pow_1 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_3, 2), kwargs = {})
#   %mean : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
#   %add : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
#   %rsqrt : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
#   %mul : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_2, %rsqrt), kwargs = {})
#   %convert_element_type_4 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
#   %mul_1 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_4, %arg6_1), kwargs = {})
#   return %buf1,%mul_1
# Fused concat + RMSNorm, compiled as a Triton persistent reduction: each
# 1024-wide row is reduced in a single block (R0_BLOCK == r0_numel == 1024,
# r0_offset == 0).
#
# Indexing, per output row x in [0, 50) (xmask guards rows >= 50):
#   x0 = x % 5   -> index along the concatenated axis (%cat, dim 2)
#   x1 = x // 5  -> index along the size-10 axis
#   x0 == 0      -> reads the 1024-element in_ptr0 (the broadcast %expand input)
#   x0 in 1..4   -> reads in_ptr1 at r + 1024*(x0 - 1) + 4096*x1
#                   (%addmm viewed as [10, 4, 1024])
#
# Each row v is then computed as v * rsqrt(sum(v**2) / 1024 + 1e-5) * in_ptr2[r].
# The result is stored to out_ptr1 at r + 1024*x, i.e. a contiguous
# [50, 1024] bf16 buffer.
#
# NOTE(review): the bf16 cast of %convert_element_type_4 is not materialized
# here. tmp23 is tmp22.to(tl.float32), so the weight multiply runs in fp32
# before the final bf16 store. This is presumably Inductor's upcast-to-fp32
# codegen and may differ slightly from eager rounding -- confirm if bitwise
# parity matters.
# NOTE(review): in_ptr0/in_ptr1/in_ptr2 presumably bind arg3_1/addmm/arg6_1,
# in the order of the graph-fragment placeholders; the launch site is outside
# this view.
#
# The triple-quoted source is compiled verbatim. Any edit to it changes the
# generated kernel, and presumably its compile-cache key.
triton_per_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0 = async_compile.triton('triton_per_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 290816}}
)
@triton.jit
def triton_per_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = (xindex % 5)
    r0_2 = r0_index
    x1 = xindex // 5
    x3 = xindex
    tmp24 = tl.load(in_ptr2 + (r0_2), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 5, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), tmp6 & xmask, other=0.0).to(tl.float32)
    tmp10 = tl.where(tmp4, tmp5, tmp9)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tmp11 * tmp11
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
    tmp15 = tl.where(xmask, tmp13, 0)
    tmp16 = tl.sum(tmp15, 1)[:, None].to(tl.float32)
    tmp17 = tl.full([1, 1], 1024.0, tl.float32)
    tmp18 = (tmp16 / tmp17)
    tmp19 = tl.full([1, 1], 1e-05, tl.float32)
    tmp20 = tmp18 + tmp19
    tmp21 = libdevice.rsqrt(tmp20)
    tmp22 = tmp11 * tmp21
    tmp23 = tmp22.to(tl.float32)
    tmp25 = tmp23 * tmp24
    tl.store(out_ptr1 + (r0_2 + 1024*x3), tmp25, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/55/c55kgnye7alutaujl2oihdx462ox7hodcvyuqwbulcszrrozpxe5.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_12
#   k_embed => add_2
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_14
#   key_states_3 => clone_1
#   mul_2 => mul_2
#   mul_3 => mul_3
#   mul_4 => mul_4
#   mul_5 => mul_5
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_11
#   q_embed => add_1
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_13
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm : Tensor "bf16[50, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm]
#   %arg4_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %arg5_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %view_4 : Tensor "bf16[10, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [10, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[10, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [10, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[10, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [10, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[10, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [10, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_12 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_12, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg4_1, [%iota]), kwargs = {})
#   %mul_2 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg : Tensor "f32[10, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_3 : Tensor "f32[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_1 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_2, %mul_3), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_4 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_12, %index), kwargs = {})
#   %neg_1 : Tensor "f32[10, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_5 : Tensor "f32[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_2 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
#   %convert_element_type_14 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_14,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[10, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [10, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[10, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [10, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf6
# Fused rotary embedding for the query path (16 heads x 128 dims), writing the
# contiguous [10, 16, 5, 128] tensor (%clone) consumed by the flash-attention op.
#
# Indexing, per flat element x in [0, 102400):
#   d = x % 128
#   h = (x // 128) % 16
#   s = (x // 2048) % 5
#   b = x // 10240
# so in_ptr0 is read contiguously as [10, 5, 16, 128] (%mm via %view_4 / %view_9).
#
# Inputs:
#   in_ptr1[d + 128*s] -> multiplies x (cos role, %index)
#   in_ptr2[d + 128*s] -> multiplies the rotated half (sin role, %index_1)
#   Only rows 0..4 of each table are used, matching %iota(5) indexing.
#
# rotate_half: d < 64 takes -x[d + 64]; d >= 64 takes x[d - 64].
# The output x*cos + rotate_half(x)*sin is stored at
# d + 128*s + 640*h + 10240*b, i.e. a [b, h, s, d] contiguous layout.
#
# tmp7 / tmp17 are always-true range checks emitted by codegen and are unused.
# xmask is all-true (no bounds mask). 102400 is a multiple of 4096.
# NOTE(review): this presumably relies on every XBLOCK the launcher can pick
# dividing 102400 -- confirm.
#
# The triple-quoted source is compiled verbatim; do not hand-edit it.
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 131072}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 1026560}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 102400
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x4 = xindex
    x0 = (xindex % 128)
    x2 = ((xindex // 2048) % 5)
    x5 = xindex // 128
    x1 = ((xindex // 128) % 16)
    x3 = xindex // 10240
    tmp0 = tl.load(in_ptr0 + (x4), None).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0 + 128*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (x0 + 128*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = x0
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 64, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tl.load(in_ptr0 + (64 + 128*x5 + (x0)), tmp9, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp5 >= tmp8
    tmp16 = tl.full([1], 128, tl.int64)
    tmp17 = tmp5 < tmp16
    tmp18 = tl.load(in_ptr0 + (128*x5 + ((-64) + x0)), tmp15, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tl.where(tmp9, tmp14, tmp21)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp4 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*x2 + 640*x1 + 10240*x3), tmp27, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/y5/cy5jwtm6gwpxcomrvfd2ndbrzfoz5p2nhelrkipkcofzyn4jtnyb.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_12
#   k_embed => add_2
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_14
#   key_states_3 => clone_1
#   mul_2 => mul_2
#   mul_3 => mul_3
#   mul_4 => mul_4
#   mul_5 => mul_5
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_11
#   q_embed => add_1
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_13
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm_1 : Tensor "bf16[50, 256][256, 1]cuda:0" = PlaceHolder[target=mm_1]
#   %arg4_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg4_1]
#   %arg5_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %view_4 : Tensor "bf16[10, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [10, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[10, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [10, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[10, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [10, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[10, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [10, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_12 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_12, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg4_1, [%iota]), kwargs = {})
#   %mul_2 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg : Tensor "f32[10, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_3 : Tensor "f32[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_1 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_2, %mul_3), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_4 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_12, %index), kwargs = {})
#   %neg_1 : Tensor "f32[10, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_5 : Tensor "f32[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_2 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
#   %convert_element_type_14 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_14,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[10, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [10, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[10, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [10, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf7
# Fused rotary embedding for the key path (2 heads x 128 dims), writing the
# contiguous [10, 2, 5, 128] tensor (%clone_1) consumed by the flash-attention op.
#
# Indexing, per flat element x in [0, 12800):
#   d = x % 128
#   h = (x // 128) % 2
#   s = (x // 256) % 5
#   b = x // 1280
# so in_ptr0 is read contiguously as [10, 5, 2, 128] (%mm_1 via %view_6 / %view_10).
#
# Inputs:
#   in_ptr1[d + 128*s] -> multiplies x (cos role, %index)
#   in_ptr2[d + 128*s] -> multiplies the rotated half (sin role, %index_1)
#
# rotate_half: d < 64 takes -x[d + 64]; d >= 64 takes x[d - 64].
# The output x*cos + rotate_half(x)*sin is stored at
# d + 128*s + 640*h + 1280*b, i.e. a [b, h, s, d] contiguous layout.
#
# Unlike the query kernel, every load and the store are guarded by
# xmask = x < 12800; note that 12800 is not a multiple of 1024.
# tmp7 / tmp17 are always-true range checks emitted by codegen and are unused.
#
# The triple-quoted source is compiled verbatim; do not hand-edit it.
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16384}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 130560}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 12800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x4 = xindex
    x0 = (xindex % 128)
    x2 = ((xindex // 256) % 5)
    x5 = xindex // 128
    x1 = ((xindex // 128) % 2)
    x3 = xindex // 1280
    tmp0 = tl.load(in_ptr0 + (x4), xmask).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = x0
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 64, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tl.load(in_ptr0 + (64 + 128*x5 + (x0)), tmp9 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp5 >= tmp8
    tmp16 = tl.full([1], 128, tl.int64)
    tmp17 = tmp5 < tmp16
    tmp18 = tl.load(in_ptr0 + (128*x5 + ((-64) + x0)), tmp15 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tl.where(tmp9, tmp14, tmp21)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp4 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*x2 + 640*x1 + 1280*x3), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/pd/cpdb4h3wdyb46h3mgw2ujl2vzjmncm4jywgn5j247ummfu3pkfay.py
# Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_1 => cat_1
#   cat_2 => cat_2
#   chunk => split
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_12
#   k_embed => add_2
#   key_states => view_6
#   key_states_1 => permute_5
#   key_states_2 => convert_element_type_14
#   key_states_3 => clone_1
#   mul_2 => mul_2
#   mul_3 => mul_3
#   mul_4 => mul_4
#   mul_5 => mul_5
#   neg => neg
#   neg_1 => neg_1
#   position_ids => iota
#   q => convert_element_type_11
#   q_embed => add_1
#   query_states => view_4
#   query_states_1 => permute_4
#   query_states_2 => convert_element_type_13
#   query_states_3 => clone
#   sin => index_1
#   value_states => view_8
#   value_states_1 => permute_6
#   value_states_2 => clone_2
#   view => view_9
#   view_1 => view_10
#   view_2 => view_11
# Graph fragment:
#   %mm_2 : Tensor "bf16[50, 256][256, 1]cuda:0" = PlaceHolder[target=mm_2]
#   %view_4 : Tensor "bf16[10, 5, 2048][10240, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [10, 5, 2048]), kwargs = {})
#   %view_9 : Tensor "bf16[10, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_4, [10, 5, 16, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_9, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_11 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_11, 64, -1), kwargs = {})
#   %view_6 : Tensor "bf16[10, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [10, 5, 256]), kwargs = {})
#   %view_10 : Tensor "bf16[10, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [10, 5, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_10, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_12 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_5, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_12, 64, -1), kwargs = {})
#   %iota : Tensor "i64[5][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (5,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg4_1, [%iota]), kwargs = {})
#   %mul_2 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_11, %index), kwargs = {})
#   %neg : Tensor "f32[10, 16, 5, 64][5120, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_1 : Tensor "f32[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[5, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg5_1, [%iota]), kwargs = {})
#   %mul_3 : Tensor "f32[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_1 : Tensor "f32[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_2, %mul_3), kwargs = {})
#   %convert_element_type_13 : Tensor "bf16[10, 16, 5, 128][10240, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_1, torch.bfloat16), kwargs = {})
#   %clone : Tensor "bf16[10, 16, 5, 128][10240, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_13,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_4 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_12, %index), kwargs = {})
#   %neg_1 : Tensor "f32[10, 2, 5, 64][640, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_2 : Tensor "f32[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_5 : Tensor "f32[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %index_1), kwargs = {})
#   %add_2 : Tensor "f32[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
#   %convert_element_type_14 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
#   %clone_1 : Tensor "bf16[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_14,), kwargs = {memory_format: torch.contiguous_format})
#   %view_8 : Tensor "bf16[10, 5, 256][1280, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [10, 5, 256]), kwargs = {})
#   %view_11 : Tensor "bf16[10, 5, 2, 128][1280, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [10, 5, 2, 128]), kwargs = {})
#   %permute_6 : Tensor "bf16[10, 2, 5, 128][1280, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %clone_2 : Tensor "bf16[10, 2, 5, 128][1280, 640, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_6,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone, %clone_1, %clone_2), kwargs = {scale: 0.08838834764831843})
#   return %buf8
# Inductor-generated pointwise copy kernel over 12800 bf16 elements
# (`xnumel = 12800`; both pointers are '*bf16' in the signature).
# For flat output index i it computes
#     x0 = i % 128, x1 = (i // 128) % 5, x2 = (i // 640) % 2, x3 = i // 1280
# and copies in_ptr0[x0 + 128*x2 + 256*x1 + 1280*x3] -> out_ptr0[i].
# In other words, the input is read as a contiguous [10, 5, 2, 128] tensor and
# written out as a contiguous [10, 2, 5, 128] tensor, i.e. a permute(0, 2, 1, 3)
# followed by a contiguous clone. The kernel does no arithmetic, so this is
# presumably permute_6 -> clone_2 from the graph fragment above
# (NOTE(review): confirm at the launch site).
# The load is upcast with `.to(tl.float32)`. out_ptr0 is '*bf16', so the value
# is written back as bf16 on store. xmask (xindex < 12800) guards the
# out-of-range tail of the last program.
# NOTE: the kernel body is a runtime string handed to async_compile.triton.
# Do not hand-edit it, since that changes the compiled kernel; regenerate instead.
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16384}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 76800}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 12800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 128)
    x1 = ((xindex // 128) % 5)
    x2 = ((xindex // 640) % 2)
    x3 = xindex // 1280
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + 256*x1 + 1280*x3), xmask).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/tm/ctmg5myqx4p6qesgj6oz2xhx7kkbwdeg4ssgbd7ag53moumwcnga.py
# Topologically Sorted Source Nodes: [transpose_3, attn_output_1], Original ATen: [aten.transpose, aten.clone]
# Source node to ATen node mapping:
#   attn_output_1 => clone_3
#   transpose_3 => permute_7
# Graph fragment:
#   %getitem_4 : Tensor "bf16[10, 16, 5, 128][10240, 640, 128, 1]cuda:0" = PlaceHolder[target=getitem_4]
#   %permute_7 : Tensor "bf16[10, 5, 16, 128][10240, 128, 640, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%getitem_4, [0, 2, 1, 3]), kwargs = {})
#   %clone_3 : Tensor "bf16[10, 5, 16, 128][10240, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_7,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_3
# Inductor-generated pointwise copy kernel over 102400 bf16 elements
# (`xnumel = 102400`). For flat output index i it computes
#     x0 = i % 128, x1 = (i // 128) % 16, x2 = (i // 2048) % 5, x3 = i // 10240
# and sets out_ptr0[i] = in_ptr0[x0 + 128*x2 + 640*x1 + 10240*x3]. A contiguous
# [10, 16, 5, 128] input is thus rewritten as a contiguous [10, 5, 16, 128]
# output. This matches permute_7 -> clone_3 (attn_output_1) in the graph
# fragment above.
# Loads and stores are unmasked (mask None; xmask is all True). This is
# presumably because 102400 = 100 * 1024 is a multiple of every XBLOCK the
# pointwise heuristics may pick.
# NOTE(review): confirm against the candidate configs.
# NOTE: the kernel body is a runtime string handed to async_compile.triton.
# Do not hand-edit it; regenerate instead.
triton_poi_fused_clone_transpose_4 = async_compile.triton('triton_poi_fused_clone_transpose_4', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 131072}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_transpose_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 614400}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_transpose_4(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 102400
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 128)
    x1 = ((xindex // 128) % 16)
    x2 = ((xindex // 2048) % 5)
    x3 = xindex // 10240
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + 640*x1 + 10240*x3), None).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/yz/cyznavro5e22jgkrqsqpugoqffc4ecu35fo7geugtty6z6rp2nri.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, to_6, pow_2, variance_1, add_4, rsqrt_1, mul_6, hidden_1, hidden_states_2], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_4 => add_4
#   attn_output_3 => view_14
#   hidden_1 => convert_element_type_18
#   hidden_states_1 => add_3
#   hidden_states_2 => mul_7
#   mul_6 => mul_6
#   pow_2 => pow_2
#   rsqrt_1 => rsqrt_1
#   special_tokens => expand
#   to_6 => convert_element_type_17
#   variance_1 => mean_1
#   x => view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %addmm : Tensor "bf16[40, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %mm_3 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %buf17 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf17]
#   %arg11_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg11_1]
#   %expand : Tensor "bf16[1, 10, 1, 1024][1024, 0, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg3_1, [1, 10, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, 10, 4, 1024][40960, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, 10, 4, 1024]), kwargs = {})
#   %cat : Tensor "bf16[1, 10, 5, 1024][51200, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %view_1], 2), kwargs = {})
#   %view_2 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [10, 5, 1024]), kwargs = {})
#   %view_14 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [10, 5, 1024]), kwargs = {})
#   %add_3 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %view_14), kwargs = {})
#   %convert_element_type_17 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_3, torch.float32), kwargs = {})
#   %pow_2 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_17, 2), kwargs = {})
#   %mean_1 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
#   %add_4 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
#   %rsqrt_1 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_4,), kwargs = {})
#   %mul_6 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_3, %rsqrt_1), kwargs = {})
#   %convert_element_type_18 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_6, torch.bfloat16), kwargs = {})
#   %mul_7 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_18, %arg11_1), kwargs = {})
#   return %buf17,%mul_7
# Inductor-generated persistent reduction: fused residual add + RMS normalization.
# There is one row per xindex (xnumel = 50). Each row is reduced over
# r0_numel = 1024 columns inside a single R0_BLOCK of 1024, so there is no
# reduction loop.
#
# For row = xindex (x0 = row % 5, x1 = row // 5) and column r:
#   base = in_ptr0[r]                             if x0 == 0  (tmp4: x0 < 1)
#          in_ptr1[r + 1024*(x0 - 1) + 4096*x1]   otherwise   (tmp6: x0 >= 1)
#     (this is the `cat` of the expanded row and the reshaped addmm rows along
#      dim 2 in the graph fragment above)
#   h    = base + in_ptr2[r + 1024*row]           (add_3, the residual add)
#   out_ptr1[r + 1024*row] = h * rsqrt(sum_r(h*h) / 1024 + 1e-5) * in_ptr3[r]
# tl.where(xmask, ..., 0) zeroes out-of-range rows before the tl.sum.
# in_ptr0..in_ptr3 presumably correspond to arg3_1, addmm, mm_3 and arg11_1 from
# the placeholder list above. NOTE(review): confirm at the launch site.
#
# Precision: all math stays in fp32. The graph's bf16 intermediates (add_3,
# convert_element_type_18) are not rounded here, because tmp13 and tmp25 are
# fp32 -> fp32 no-op casts. Results may therefore differ slightly from an
# eager bf16 run.
# The graph fragment also returns buf17 (the row mean), but this kernel never
# stores it; only out_ptr1 is written.
# tmp2 (x0 >= 0) and tmp8 (x0 < 5) are generated bounds checks for the cat
# and are not used afterwards.
# NOTE: the kernel body is a runtime string handed to async_compile.triton.
# Do not hand-edit it; regenerate instead.
triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 393216}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = (xindex % 5)
    r0_2 = r0_index
    x1 = xindex // 5
    x3 = xindex
    tmp11 = tl.load(in_ptr2 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp26 = tl.load(in_ptr3 + (r0_2), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 5, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), tmp6 & xmask, other=0.0).to(tl.float32)
    tmp10 = tl.where(tmp4, tmp5, tmp9)
    tmp12 = tmp10 + tmp11
    tmp13 = tmp12.to(tl.float32)
    tmp14 = tmp13 * tmp13
    tmp15 = tl.broadcast_to(tmp14, [XBLOCK, R0_BLOCK])
    tmp17 = tl.where(xmask, tmp15, 0)
    tmp18 = tl.sum(tmp17, 1)[:, None].to(tl.float32)
    tmp19 = tl.full([1, 1], 1024.0, tl.float32)
    tmp20 = (tmp18 / tmp19)
    tmp21 = tl.full([1, 1], 1e-05, tl.float32)
    tmp22 = tmp20 + tmp21
    tmp23 = libdevice.rsqrt(tmp22)
    tmp24 = tmp13 * tmp23
    tmp25 = tmp24.to(tl.float32)
    tmp27 = tmp25 * tmp26
    tl.store(out_ptr1 + (r0_2 + 1024*x3), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zv/czvcawxtlkkawoplkibo3wadyx7u7jr4tn7fz6n6t6hnzvv775bl.py
# Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
# Source node to ATen node mapping:
#   linear_5 => view_16
#   linear_6 => view_18
#   mul_8 => mul_8
#   silu => add_5, convert_element_type_21, convert_element_type_22, div, exp, neg_2
# Graph fragment:
#   %mm_4 : Tensor "bf16[50, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_4]
#   %mm_5 : Tensor "bf16[50, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_5]
#   %view_16 : Tensor "bf16[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_4, [10, 5, 4096]), kwargs = {})
#   %convert_element_type_21 : Tensor "f32[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_16, torch.float32), kwargs = {})
#   %neg_2 : Tensor "f32[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%convert_element_type_21,), kwargs = {})
#   %exp : Tensor "f32[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%neg_2,), kwargs = {})
#   %add_5 : Tensor "f32[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%exp, 1), kwargs = {})
#   %div : Tensor "f32[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_21, %add_5), kwargs = {})
#   %convert_element_type_22 : Tensor "bf16[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%div, torch.bfloat16), kwargs = {})
#   %view_18 : Tensor "bf16[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_5, [10, 5, 4096]), kwargs = {})
#   %mul_8 : Tensor "bf16[10, 5, 4096][20480, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_22, %view_18), kwargs = {})
#   return %mul_8
# Inductor-generated pointwise kernel for a gated SiLU product, computed in place
# over 204800 bf16 elements (`xnumel = 204800`):
#     in_out_ptr0[i] = silu(in_out_ptr0[i]) * in_ptr0[i]
# where silu(x) = x / (1 + exp(-x)), evaluated in fp32.
# in_out_ptr0 is overwritten with the result (mutated_arg_names = ['in_out_ptr0']).
# Per the graph fragment above (mul_8 = silu(linear_5) * linear_6),
# in_out_ptr0 presumably holds mm_4 and in_ptr0 holds mm_5.
# NOTE(review): confirm at the launch site.
# The graph's bf16 rounding of silu's output (convert_element_type_22) is not
# reproduced, because tmp7 is an fp32 -> fp32 no-op cast before the multiply.
# Loads and stores are unmasked (mask None). This is presumably because
# 204800 = 200 * 1024 is a multiple of every candidate XBLOCK.
# NOTE: the kernel body is a runtime string handed to async_compile.triton.
# Do not hand-edit it; regenerate instead.
triton_poi_fused__unsafe_view_mul_silu_6 = async_compile.triton('triton_poi_fused__unsafe_view_mul_silu_6', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 262144}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__unsafe_view_mul_silu_6', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 1638400}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__unsafe_view_mul_silu_6(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 204800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = -tmp1
    tmp3 = libdevice.exp(tmp2)
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = (tmp1 / tmp5)
    tmp7 = tmp6.to(tl.float32)
    tmp9 = tmp7 * tmp8
    tl.store(in_out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/wc/cwcugfllrrk4tdxjj5pykwvvyl3uorikypmuew5onxzeir3vd6nh.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_6 => add_7
#   attn_output_3 => view_14
#   hidden_2 => convert_element_type_28
#   hidden_states_1 => add_3
#   hidden_states_3 => view_20
#   hidden_states_4 => add_6
#   hidden_states_5 => mul_10
#   mul_9 => mul_9
#   pow_3 => pow_3
#   rsqrt_2 => rsqrt_2
#   special_tokens => expand
#   to_8 => convert_element_type_27
#   variance_2 => mean_2
#   x => view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %addmm : Tensor "bf16[40, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %mm_3 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %buf23 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf23]
#   %arg15_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg15_1]
#   %expand : Tensor "bf16[1, 10, 1, 1024][1024, 0, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg3_1, [1, 10, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, 10, 4, 1024][40960, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, 10, 4, 1024]), kwargs = {})
#   %cat : Tensor "bf16[1, 10, 5, 1024][51200, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %view_1], 2), kwargs = {})
#   %view_2 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [10, 5, 1024]), kwargs = {})
#   %view_14 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [10, 5, 1024]), kwargs = {})
#   %add_3 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %view_14), kwargs = {})
#   %view_20 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_6, [10, 5, 1024]), kwargs = {})
#   %add_6 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %view_20), kwargs = {})
#   %convert_element_type_27 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_6, torch.float32), kwargs = {})
#   %pow_3 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_27, 2), kwargs = {})
#   %mean_2 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
#   %add_7 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {})
#   %rsqrt_2 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_7,), kwargs = {})
#   %mul_9 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_6, %rsqrt_2), kwargs = {})
#   %convert_element_type_28 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_9, torch.bfloat16), kwargs = {})
#   %mul_10 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_28, %arg15_1), kwargs = {})
#   return %buf23,%mul_10
# Kernel _7: concat + two residual adds, fused with RMSNorm (persistent reduction).
# Layout: xnumel = 50 rows by r0_numel = 1024 features.
#   - Row index = batch x1 (row // 5) and token x0 (row % 5).
#   - R0_BLOCK == r0_numel, so each program reduces whole rows in a single pass.
#   - Each program handles XBLOCK rows; rows >= 50 are masked out (xmask).
# Per row, computed in fp32:
#   h   = cat_row + in_ptr2[row] + in_ptr3[row]
#   out = h * rsqrt(sum(h * h) / 1024 + 1e-05) * in_ptr4[feature]   -> out_ptr1 (bf16)
# cat_row is materialized on the fly rather than read from a concat buffer:
#   - Token 0 (x0 == 0) reads the single 1024-vector at in_ptr0. Its index
#     ignores x1, so it is broadcast over all batches.
#   - Tokens 1..4 read row (x0 - 1) of batch x1 from in_ptr1, which holds
#     4 x 1024 per batch (hence the 4096 * x1 stride).
# tmp2 / tmp8 (the cat's lower/upper bound comparisons) and r0_mask are computed
# but never used in the result.
# Only the normalized output is stored; the pre-norm sum h is not written out.
# NOTE(review): argument binding happens at the launch site, which is not visible
# here. Presumably: in_ptr0=arg3_1, in_ptr1=addmm, in_ptr2=mm_3, in_ptr3=mm_6,
# in_ptr4=arg15_1 (per the graph fragment). Confirm at the call site.
# NOTE(review): the graph rounds to bf16 after each add (add_3, add_6) and at
# convert_element_type_28. This kernel keeps everything in fp32
# (e.g. tmp27 = tmp26.to(tl.float32)) and rounds only when storing to the bf16
# out_ptr1, so results may differ slightly from eager. This is presumably governed
# by inductor's emulate_precision_casts config; verify.
triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 495616}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = (xindex % 5)
    r0_2 = r0_index
    x1 = xindex // 5
    x3 = xindex
    tmp11 = tl.load(in_ptr2 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp13 = tl.load(in_ptr3 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp28 = tl.load(in_ptr4 + (r0_2), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 5, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), tmp6 & xmask, other=0.0).to(tl.float32)
    tmp10 = tl.where(tmp4, tmp5, tmp9)
    tmp12 = tmp10 + tmp11
    tmp14 = tmp12 + tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp16 = tmp15 * tmp15
    tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
    tmp19 = tl.where(xmask, tmp17, 0)
    tmp20 = tl.sum(tmp19, 1)[:, None].to(tl.float32)
    tmp21 = tl.full([1, 1], 1024.0, tl.float32)
    tmp22 = (tmp20 / tmp21)
    tmp23 = tl.full([1, 1], 1e-05, tl.float32)
    tmp24 = tmp22 + tmp23
    tmp25 = libdevice.rsqrt(tmp24)
    tmp26 = tmp15 * tmp25
    tmp27 = tmp26.to(tl.float32)
    tmp29 = tmp27 * tmp28
    tl.store(out_ptr1 + (r0_2 + 1024*x3), tmp29, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/5q/c5qeyo5ccsjrtarwml6lucxzaete6haz3nivgr6nxskfzcq5pqfz.py
# Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, attn_output_7, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_10 => add_11
#   attn_output_3 => view_14
#   attn_output_7 => view_32
#   hidden_3 => convert_element_type_42
#   hidden_states_1 => add_3
#   hidden_states_3 => view_20
#   hidden_states_4 => add_6
#   hidden_states_6 => add_10
#   hidden_states_7 => mul_16
#   mul_15 => mul_15
#   pow_4 => pow_4
#   rsqrt_3 => rsqrt_3
#   special_tokens => expand
#   to_14 => convert_element_type_41
#   variance_3 => mean_3
#   x => view_1
#   x_1 => cat
#   x_2 => view_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 1, 1, 1024][1024, 1024, 1024, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %addmm : Tensor "bf16[40, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %mm_3 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %mm_10 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_10]
#   %add_10 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_10]
#   %buf40 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf40]
#   %arg20_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg20_1]
#   %expand : Tensor "bf16[1, 10, 1, 1024][1024, 0, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%arg3_1, [1, 10, 1, -1]), kwargs = {})
#   %view_1 : Tensor "bf16[1, 10, 4, 1024][40960, 4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [1, 10, 4, 1024]), kwargs = {})
#   %cat : Tensor "bf16[1, 10, 5, 1024][51200, 5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%expand, %view_1], 2), kwargs = {})
#   %view_2 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [10, 5, 1024]), kwargs = {})
#   %view_14 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [10, 5, 1024]), kwargs = {})
#   %add_3 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_2, %view_14), kwargs = {})
#   %view_20 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_6, [10, 5, 1024]), kwargs = {})
#   %add_6 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %view_20), kwargs = {})
#   %view_32 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_10, [10, 5, 1024]), kwargs = {})
#   %add_10 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_6, %view_32), kwargs = {})
#   %convert_element_type_41 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_10, torch.float32), kwargs = {})
#   %pow_4 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_41, 2), kwargs = {})
#   %mean_3 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {})
#   %add_11 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {})
#   %rsqrt_3 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_11,), kwargs = {})
#   %mul_15 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_10, %rsqrt_3), kwargs = {})
#   %convert_element_type_42 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_15, torch.bfloat16), kwargs = {})
#   %mul_16 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_42, %arg20_1), kwargs = {})
#   return %add_10,%buf40,%mul_16
# Kernel _8: concat + three residual adds, fused with RMSNorm; writes the
# pre-norm sum back in place.
# Layout: xnumel = 50 rows by r0_numel = 1024 features.
#   - Row index = batch x1 (row // 5) and token x0 (row % 5).
#   - R0_BLOCK == r0_numel, so each program reduces whole rows in one pass.
#   - Rows >= 50 are masked (xmask).
# Per row, computed in fp32:
#   h   = cat_row + in_out_ptr0[row] + in_ptr2[row] + in_ptr3[row]
#   out = h * rsqrt(sum(h * h) / 1024 + 1e-05) * in_ptr4[feature]
# cat_row is built on the fly:
#   - Token 0 reads the single 1024-vector at in_ptr0, broadcast over batches
#     (its index ignores x1).
#   - Tokens 1..4 read row (x0 - 1) of batch x1 from in_ptr1 (4096 * x1 stride).
# Two stores:
#   - h, rounded to bf16, overwrites in_out_ptr0; hence
#     mutated_arg_names=['in_out_ptr0'].
#   - out goes to out_ptr1.
# Every element of in_out_ptr0 is loaded and stored at the same index by the
# same program, so the in-place update has no cross-program hazard.
# tmp2 / tmp8 (cat bound comparisons) and r0_mask are computed but unused.
# NOTE(review): presumably in_out_ptr0 holds mm_3 on entry and add_10 on exit
# (add_10 is in the returned values), with in_ptr0=arg3_1, in_ptr1=addmm,
# in_ptr2=mm_6, in_ptr3=mm_10, in_ptr4=arg20_1. Confirm at the launch site.
# NOTE(review): the graph's intermediate bf16 roundings (add_3, add_6,
# convert_element_type_42) are not applied here. Values stay fp32 until the
# bf16 stores, so results may differ slightly from eager. This is presumably
# governed by inductor's emulate_precision_casts config; verify.
triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 802816}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = (xindex % 5)
    r0_2 = r0_index
    x1 = xindex // 5
    x3 = xindex
    tmp11 = tl.load(in_out_ptr0 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp13 = tl.load(in_ptr2 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp15 = tl.load(in_ptr3 + (r0_2 + 1024*x3), xmask, other=0.0).to(tl.float32)
    tmp30 = tl.load(in_ptr4 + (r0_2), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (tl.broadcast_to(r0_2, [XBLOCK, R0_BLOCK])), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 5, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tl.load(in_ptr1 + (r0_2 + 1024*((-1) + x0) + 4096*x1), tmp6 & xmask, other=0.0).to(tl.float32)
    tmp10 = tl.where(tmp4, tmp5, tmp9)
    tmp12 = tmp10 + tmp11
    tmp14 = tmp12 + tmp13
    tmp16 = tmp14 + tmp15
    tmp17 = tmp16.to(tl.float32)
    tmp18 = tmp17 * tmp17
    tmp19 = tl.broadcast_to(tmp18, [XBLOCK, R0_BLOCK])
    tmp21 = tl.where(xmask, tmp19, 0)
    tmp22 = tl.sum(tmp21, 1)[:, None].to(tl.float32)
    tmp23 = tl.full([1, 1], 1024.0, tl.float32)
    tmp24 = (tmp22 / tmp23)
    tmp25 = tl.full([1, 1], 1e-05, tl.float32)
    tmp26 = tmp24 + tmp25
    tmp27 = libdevice.rsqrt(tmp26)
    tmp28 = tmp17 * tmp27
    tmp29 = tmp28.to(tl.float32)
    tmp31 = tmp29 * tmp30
    tl.store(in_out_ptr0 + (r0_2 + 1024*x3), tmp16, xmask)
    tl.store(out_ptr1 + (r0_2 + 1024*x3), tmp31, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ee/ceejh5gzih2rtl2njpiv77swtmfp7gqv5af2tr2otndj46twie63.py
# Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_12 => add_14
#   hidden_4 => convert_element_type_52
#   hidden_states_10 => mul_19
#   hidden_states_8 => view_38
#   hidden_states_9 => add_13
#   mul_18 => mul_18
#   pow_5 => pow_5
#   rsqrt_4 => rsqrt_4
#   to_16 => convert_element_type_51
#   variance_4 => mean_4
# Graph fragment:
#   %add_10 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_10]
#   %mm_13 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %buf46 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf46]
#   %arg24_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg24_1]
#   %view_38 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [10, 5, 1024]), kwargs = {})
#   %add_13 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_10, %view_38), kwargs = {})
#   %convert_element_type_51 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_13, torch.float32), kwargs = {})
#   %pow_5 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_51, 2), kwargs = {})
#   %mean_4 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {})
#   %add_14 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {})
#   %rsqrt_4 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_14,), kwargs = {})
#   %mul_18 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_13, %rsqrt_4), kwargs = {})
#   %convert_element_type_52 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_18, torch.bfloat16), kwargs = {})
#   %mul_19 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_52, %arg24_1), kwargs = {})
#   return %buf46,%mul_19
# Kernel _9: one residual add fused with RMSNorm (persistent reduction).
# Layout: xnumel = 50 rows (x0 = row) by r0_numel = 1024 features.
#   - R0_BLOCK == r0_numel, so each program reduces whole rows in one pass.
#   - Rows >= 50 are masked (xmask).
# Per row, computed in fp32:
#   h   = in_ptr0[row] + in_ptr1[row]
#   out = h * rsqrt(sum(h * h) / 1024 + 1e-05) * in_ptr2[feature]   -> out_ptr1 (bf16)
# The weight vector in_ptr2 is indexed by feature only and loaded with an
# evict_last cache hint.
# The sum h (add_13 in the graph fragment) is not stored. Later kernels whose
# graph fragments list add_10 and mm_13 as inputs recompute it instead.
# NOTE(review): presumably in_ptr0=add_10, in_ptr1=mm_13, in_ptr2=arg24_1.
# Confirm at the launch site.
# NOTE(review): the graph rounds to bf16 at add_13 and convert_element_type_52.
# This kernel keeps fp32 throughout (tmp15 = tmp14.to(tl.float32)) and rounds
# only on the bf16 store, so results may differ slightly from eager.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 411648}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp3 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
    tmp7 = tl.where(xmask, tmp5, 0)
    tmp8 = tl.sum(tmp7, 1)[:, None].to(tl.float32)
    tmp9 = tl.full([1, 1], 1024.0, tl.float32)
    tmp10 = (tmp8 / tmp9)
    tmp11 = tl.full([1, 1], 1e-05, tl.float32)
    tmp12 = tmp10 + tmp11
    tmp13 = libdevice.rsqrt(tmp12)
    tmp14 = tmp3 * tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp17 = tmp15 * tmp16
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp17, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qk/cqkuno7afzae7eq3xp5cpjvq3snkliaqhn3aovwrqeqgqiqme73z.py
# Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, to_22, pow_6, variance_5, add_16, rsqrt_5, mul_24, hidden_5, hidden_states_12], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_16 => add_18
#   attn_output_11 => view_50
#   hidden_5 => convert_element_type_66
#   hidden_states_11 => add_17
#   hidden_states_12 => mul_25
#   hidden_states_8 => view_38
#   hidden_states_9 => add_13
#   mul_24 => mul_24
#   pow_6 => pow_6
#   rsqrt_5 => rsqrt_5
#   to_22 => convert_element_type_65
#   variance_5 => mean_5
# Graph fragment:
#   %add_10 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_10]
#   %mm_13 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %mm_17 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_17]
#   %buf62 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf62]
#   %arg29_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg29_1]
#   %view_38 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [10, 5, 1024]), kwargs = {})
#   %add_13 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_10, %view_38), kwargs = {})
#   %view_50 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_17, [10, 5, 1024]), kwargs = {})
#   %add_17 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_13, %view_50), kwargs = {})
#   %convert_element_type_65 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_17, torch.float32), kwargs = {})
#   %pow_6 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_65, 2), kwargs = {})
#   %mean_5 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_6, [-1], True), kwargs = {})
#   %add_18 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_5, 1e-05), kwargs = {})
#   %rsqrt_5 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_18,), kwargs = {})
#   %mul_24 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_17, %rsqrt_5), kwargs = {})
#   %convert_element_type_66 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_24, torch.bfloat16), kwargs = {})
#   %mul_25 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_66, %arg29_1), kwargs = {})
#   return %buf62,%mul_25
# Kernel _10: two chained residual adds fused with RMSNorm (persistent reduction).
# Layout: xnumel = 50 rows (x0 = row) by r0_numel = 1024 features.
#   - R0_BLOCK == r0_numel, so each program reduces whole rows in one pass.
#   - Rows >= 50 are masked (xmask).
# Per row, computed in fp32 and summed in this order:
#   h   = (in_ptr0[row] + in_ptr1[row]) + in_ptr2[row]
#   out = h * rsqrt(sum(h * h) / 1024 + 1e-05) * in_ptr3[feature]   -> out_ptr1 (bf16)
# The intermediate add_13 is recomputed here from its inputs rather than read
# from memory. Neither add_13 nor add_17 is stored by this kernel (num_store = 1).
# NOTE(review): presumably in_ptr0=add_10, in_ptr1=mm_13, in_ptr2=mm_17,
# in_ptr3=arg29_1 (per the graph fragment). Confirm at the launch site.
# NOTE(review): the graph's bf16 roundings (add_13, add_17,
# convert_element_type_66) are not applied. Values stay fp32 until the bf16
# store, so results may differ slightly from eager.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 514048}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp18 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp5 * tmp5
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
    tmp9 = tl.where(xmask, tmp7, 0)
    tmp10 = tl.sum(tmp9, 1)[:, None].to(tl.float32)
    tmp11 = tl.full([1, 1], 1024.0, tl.float32)
    tmp12 = (tmp10 / tmp11)
    tmp13 = tl.full([1, 1], 1e-05, tl.float32)
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp5 * tmp15
    tmp17 = tmp16.to(tl.float32)
    tmp19 = tmp17 * tmp18
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp19, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/bg/cbgqfsrw6r4irxsyqztmt4ej3b5qtrlpvkeplmhuru5p2k6fy3mz.py
# Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, to_24, pow_7, variance_6, add_18, rsqrt_6, mul_27, hidden_6, hidden_states_15], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_18 => add_21
#   attn_output_11 => view_50
#   hidden_6 => convert_element_type_76
#   hidden_states_11 => add_17
#   hidden_states_13 => view_56
#   hidden_states_14 => add_20
#   hidden_states_15 => mul_28
#   hidden_states_8 => view_38
#   hidden_states_9 => add_13
#   mul_27 => mul_27
#   pow_7 => pow_7
#   rsqrt_6 => rsqrt_6
#   to_24 => convert_element_type_75
#   variance_6 => mean_6
# Graph fragment:
#   %add_10 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_10]
#   %mm_13 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %mm_17 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_17]
#   %mm_20 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_20]
#   %buf68 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf68]
#   %arg33_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg33_1]
#   %view_38 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [10, 5, 1024]), kwargs = {})
#   %add_13 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_10, %view_38), kwargs = {})
#   %view_50 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_17, [10, 5, 1024]), kwargs = {})
#   %add_17 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_13, %view_50), kwargs = {})
#   %view_56 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_20, [10, 5, 1024]), kwargs = {})
#   %add_20 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_17, %view_56), kwargs = {})
#   %convert_element_type_75 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_20, torch.float32), kwargs = {})
#   %pow_7 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_75, 2), kwargs = {})
#   %mean_6 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_7, [-1], True), kwargs = {})
#   %add_21 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_6, 1e-05), kwargs = {})
#   %rsqrt_6 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_21,), kwargs = {})
#   %mul_27 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_20, %rsqrt_6), kwargs = {})
#   %convert_element_type_76 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.bfloat16), kwargs = {})
#   %mul_28 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_76, %arg33_1), kwargs = {})
#   return %buf68,%mul_28
# Kernel _11: three chained residual adds fused with RMSNorm (persistent reduction).
# Layout: xnumel = 50 rows (x0 = row) by r0_numel = 1024 features.
#   - R0_BLOCK == r0_numel, so each program reduces whole rows in one pass.
#   - Rows >= 50 are masked (xmask).
# Per row, computed in fp32 and summed in this order:
#   h   = ((in_ptr0[row] + in_ptr1[row]) + in_ptr2[row]) + in_ptr3[row]
#   out = h * rsqrt(sum(h * h) / 1024 + 1e-05) * in_ptr4[feature]   -> out_ptr1 (bf16)
# The intermediates add_13 and add_17 are recomputed here from their inputs
# rather than read from memory. No intermediate sum is stored (num_store = 1).
# NOTE(review): presumably in_ptr0=add_10, in_ptr1=mm_13, in_ptr2=mm_17,
# in_ptr3=mm_20, in_ptr4=arg33_1 (per the graph fragment). Confirm at the
# launch site.
# NOTE(review): the graph's bf16 roundings (add_13, add_17, add_20,
# convert_element_type_76) are not applied. Values stay fp32 until the bf16
# store, so results may differ slightly from eager.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 616448}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp20 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tmp7 * tmp7
    tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
    tmp11 = tl.where(xmask, tmp9, 0)
    tmp12 = tl.sum(tmp11, 1)[:, None].to(tl.float32)
    tmp13 = tl.full([1, 1], 1024.0, tl.float32)
    tmp14 = (tmp12 / tmp13)
    tmp15 = tl.full([1, 1], 1e-05, tl.float32)
    tmp16 = tmp14 + tmp15
    tmp17 = libdevice.rsqrt(tmp16)
    tmp18 = tmp7 * tmp17
    tmp19 = tmp18.to(tl.float32)
    tmp21 = tmp19 * tmp20
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp21, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ro/croemwexguw3howo3hlrxjdpke7wwy34mwccjr23b23rohmu5m76.py
# Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, attn_output_15, hidden_states_16, to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_22 => add_25
#   attn_output_11 => view_50
#   attn_output_15 => view_68
#   hidden_7 => convert_element_type_90
#   hidden_states_11 => add_17
#   hidden_states_13 => view_56
#   hidden_states_14 => add_20
#   hidden_states_16 => add_24
#   hidden_states_17 => mul_34
#   hidden_states_8 => view_38
#   hidden_states_9 => add_13
#   mul_33 => mul_33
#   pow_8 => pow_8
#   rsqrt_7 => rsqrt_7
#   to_30 => convert_element_type_89
#   variance_7 => mean_7
# Graph fragment:
#   %add_10 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_10]
#   %mm_13 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %mm_17 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_17]
#   %mm_20 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_20]
#   %mm_24 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_24]
#   %add_24 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_24]
#   %buf85 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf85]
#   %arg38_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg38_1]
#   %view_38 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [10, 5, 1024]), kwargs = {})
#   %add_13 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_10, %view_38), kwargs = {})
#   %view_50 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_17, [10, 5, 1024]), kwargs = {})
#   %add_17 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_13, %view_50), kwargs = {})
#   %view_56 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_20, [10, 5, 1024]), kwargs = {})
#   %add_20 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_17, %view_56), kwargs = {})
#   %view_68 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_24, [10, 5, 1024]), kwargs = {})
#   %add_24 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_20, %view_68), kwargs = {})
#   %convert_element_type_89 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_24, torch.float32), kwargs = {})
#   %pow_8 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_89, 2), kwargs = {})
#   %mean_7 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_8, [-1], True), kwargs = {})
#   %add_25 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_7, 1e-05), kwargs = {})
#   %rsqrt_7 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_25,), kwargs = {})
#   %mul_33 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_24, %rsqrt_7), kwargs = {})
#   %convert_element_type_90 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_33, torch.bfloat16), kwargs = {})
#   %mul_34 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_90, %arg38_1), kwargs = {})
#   return %add_24,%buf85,%mul_34
# Fused kernel #12: four residual adds followed by an RMSNorm, done as one
# persistent reduction (each program holds a whole 1024-wide row, R0_BLOCK=1024).
#
# Kernel computation (see the embedded source below):
#   s   = in_out_ptr0 + in_ptr0 + in_ptr1 + in_ptr2 + in_ptr3   (elementwise, fp32)
#   out = s * rsqrt(mean(s**2, over the 1024 columns) + 1e-5) * in_ptr4[col]
#   in_out_ptr0 <- s    (in-place; 'mutated_arg_names' lists in_out_ptr0)
#   out_ptr1    <- out
# Presumably, by the placeholder order in the graph fragment above:
#   in_out_ptr0 = add_10 on entry / add_24 on exit, in_ptr0..3 = mm_13, mm_17,
#   mm_20, mm_24, in_ptr4 = arg38_1 (weight), out_ptr1 = mul_34.
#   TODO confirm against the .run(...) call site.
# Indexing is row-major with a row stride of 1024 (`r0_1 + 1024*x0`), and there
# are 50 rows. xmask masks rows >= 50 on both the loads and the reduction.
# NOTE(review): every intermediate (the adds, and the product before the weight
# multiply) stays in fp32. The graph's bf16 rounding steps are not emulated, so
# results can differ slightly from eager bf16. This looks like intentional
# Inductor behaviour; confirm if bit-exact parity matters.
# The kernel source string is runtime input to async_compile.triton. This file
# appears to be TorchInductor output, so do not hand-edit the string.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 923648}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp22 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp8 = tmp6 + tmp7
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp9 * tmp9
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp13 = tl.where(xmask, tmp11, 0)
    tmp14 = tl.sum(tmp13, 1)[:, None].to(tl.float32)
    tmp15 = tl.full([1, 1], 1024.0, tl.float32)
    tmp16 = (tmp14 / tmp15)
    tmp17 = tl.full([1, 1], 1e-05, tl.float32)
    tmp18 = tmp16 + tmp17
    tmp19 = libdevice.rsqrt(tmp18)
    tmp20 = tmp9 * tmp19
    tmp21 = tmp20.to(tl.float32)
    tmp23 = tmp21 * tmp22
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp8, xmask)
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp23, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ej/cejqgy72vpmfmswr5iimevsvi27wwk7fwzm3j3et365zoikzs7rz.py
# Topologically Sorted Source Nodes: [hidden_states_58, hidden_states_59, to_96, pow_25, variance_24, add_72, rsqrt_24, mul_108, hidden_24, hidden_states_60], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_72 => add_84
#   hidden_24 => convert_element_type_292
#   hidden_states_58 => view_218
#   hidden_states_59 => add_83
#   hidden_states_60 => mul_109
#   mul_108 => mul_108
#   pow_25 => pow_25
#   rsqrt_24 => rsqrt_24
#   to_96 => convert_element_type_291
#   variance_24 => mean_24
# Graph fragment:
#   %add_80 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0" = PlaceHolder[target=add_80]
#   %mm_83 : Tensor "bf16[50, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_83]
#   %buf271 : Tensor "f32[10, 5, 1][5, 1, 50]cuda:0" = PlaceHolder[target=buf271]
#   %arg114_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg114_1]
#   %view_218 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_83, [10, 5, 1024]), kwargs = {})
#   %add_83 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_80, %view_218), kwargs = {})
#   %convert_element_type_291 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_83, torch.float32), kwargs = {})
#   %pow_25 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_291, 2), kwargs = {})
#   %mean_24 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_25, [-1], True), kwargs = {})
#   %add_84 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_24, 1e-05), kwargs = {})
#   %rsqrt_24 : Tensor "f32[10, 5, 1][5, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_84,), kwargs = {})
#   %mul_108 : Tensor "f32[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_83, %rsqrt_24), kwargs = {})
#   %convert_element_type_292 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_108, torch.bfloat16), kwargs = {})
#   %mul_109 : Tensor "bf16[10, 5, 1024][5120, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_292, %arg114_1), kwargs = {})
#   return %buf271,%mul_109
# Fused kernel #13: one residual add followed by an RMSNorm, done as one
# persistent reduction (each program holds a whole 1024-wide row, R0_BLOCK=1024).
#
# Kernel computation (see the embedded source below):
#   s   = in_out_ptr0 + in_ptr0                                  (elementwise, fp32)
#   out = s * rsqrt(mean(s**2, over the 1024 columns) + 1e-5) * in_ptr1[col]
#   in_out_ptr0 <- out  (in-place; the pre-norm sum s itself is never stored)
# Presumably, by the placeholder order in the graph fragment above:
#   in_out_ptr0 = add_80 on entry / mul_109 on exit, in_ptr0 = mm_83,
#   in_ptr1 = arg114_1 (weight). TODO confirm against the .run(...) call site.
# Indexing is row-major with a row stride of 1024 (`r0_1 + 1024*x0`), and there
# are 50 rows. xmask masks rows >= 50 on both the loads and the reduction.
# NOTE(review): every intermediate stays in fp32, and the graph's bf16 rounding
# of add_83 / convert_element_type_292 is not emulated. This looks like
# intentional Inductor behaviour; confirm if bit-exact parity with eager matters.
# The kernel source string is runtime input to async_compile.triton. This file
# appears to be TorchInductor output, so do not hand-edit the string.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 64, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 411648}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13(in_out_ptr0, in_ptr0, in_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 50
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp16 = tl.load(in_ptr1 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp3 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
    tmp7 = tl.where(xmask, tmp5, 0)
    tmp8 = tl.sum(tmp7, 1)[:, None].to(tl.float32)
    tmp9 = tl.full([1, 1], 1024.0, tl.float32)
    tmp10 = (tmp8 / tmp9)
    tmp11 = tl.full([1, 1], 1e-05, tl.float32)
    tmp12 = tmp10 + tmp11
    tmp13 = libdevice.rsqrt(tmp12)
    tmp14 = tmp3 * tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp17 = tmp15 * tmp16
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp17, xmask)
''', device_str='cuda')

def partition_0(args):
    arg2_1, arg0_1, arg1_1, arg3_1, arg6_1, arg7_1, arg8_1, arg9_1, arg4_1, arg5_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1 = args
    args.clear()
    assert_size_stride(arg2_1, (1024, ), (1, ))
    assert_size_stride(arg0_1, (1, 10, 4, 64), (2560, 256, 64, 1))
    assert_size_stride(arg1_1, (1024, 64), (64, 1))
    assert_size_stride(arg3_1, (1, 1, 1, 1024), (1024, 1024, 1024, 1))
    assert_size_stride(arg6_1, (1024, ), (1, ))
    assert_size_stride(arg7_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg8_1, (256, 1024), (1024, 1))
    assert_size_stride(arg9_1, (256, 1024), (1024, 1))
    assert_size_stride(arg4_1, (32768, 128), (128, 1))
    assert_size_stride(arg5_1, (32768, 128), (128, 1))
    assert_size_stride(arg10_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg11_1, (1024, ), (1, ))
    assert_size_stride(arg12_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg13_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg14_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg15_1, (1024, ), (1, ))
    assert_size_stride(arg16_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg17_1, (256, 1024), (1024, 1))
    assert_size_stride(arg18_1, (256, 1024), (1024, 1))
    assert_size_stride(arg19_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg20_1, (1024, ), (1, ))
    assert_size_stride(arg21_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg22_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg23_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg24_1, (1024, ), (1, ))
    assert_size_stride(arg25_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg26_1, (256, 1024), (1024, 1))
    assert_size_stride(arg27_1, (256, 1024), (1024, 1))
    assert_size_stride(arg28_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg29_1, (1024, ), (1, ))
    assert_size_stride(arg30_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg31_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg32_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg33_1, (1024, ), (1, ))
    assert_size_stride(arg34_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg35_1, (256, 1024), (1024, 1))
    assert_size_stride(arg36_1, (256, 1024), (1024, 1))
    assert_size_stride(arg37_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg38_1, (1024, ), (1, ))
    assert_size_stride(arg39_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg40_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg41_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg42_1, (1024, ), (1, ))
    assert_size_stride(arg43_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg44_1, (256, 1024), (1024, 1))
    assert_size_stride(arg45_1, (256, 1024), (1024, 1))
    assert_size_stride(arg46_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg47_1, (1024, ), (1, ))
    assert_size_stride(arg48_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg49_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg50_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg51_1, (1024, ), (1, ))
    assert_size_stride(arg52_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg53_1, (256, 1024), (1024, 1))
    assert_size_stride(arg54_1, (256, 1024), (1024, 1))
    assert_size_stride(arg55_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg56_1, (1024, ), (1, ))
    assert_size_stride(arg57_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg58_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg59_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg60_1, (1024, ), (1, ))
    assert_size_stride(arg61_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg62_1, (256, 1024), (1024, 1))
    assert_size_stride(arg63_1, (256, 1024), (1024, 1))
    assert_size_stride(arg64_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg65_1, (1024, ), (1, ))
    assert_size_stride(arg66_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg67_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg68_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg69_1, (1024, ), (1, ))
    assert_size_stride(arg70_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg71_1, (256, 1024), (1024, 1))
    assert_size_stride(arg72_1, (256, 1024), (1024, 1))
    assert_size_stride(arg73_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg74_1, (1024, ), (1, ))
    assert_size_stride(arg75_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg76_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg77_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg78_1, (1024, ), (1, ))
    assert_size_stride(arg79_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg80_1, (256, 1024), (1024, 1))
    assert_size_stride(arg81_1, (256, 1024), (1024, 1))
    assert_size_stride(arg82_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg83_1, (1024, ), (1, ))
    assert_size_stride(arg84_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg85_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg86_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg87_1, (1024, ), (1, ))
    assert_size_stride(arg88_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg89_1, (256, 1024), (1024, 1))
    assert_size_stride(arg90_1, (256, 1024), (1024, 1))
    assert_size_stride(arg91_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg92_1, (1024, ), (1, ))
    assert_size_stride(arg93_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg94_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg95_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg96_1, (1024, ), (1, ))
    assert_size_stride(arg97_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg98_1, (256, 1024), (1024, 1))
    assert_size_stride(arg99_1, (256, 1024), (1024, 1))
    assert_size_stride(arg100_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg101_1, (1024, ), (1, ))
    assert_size_stride(arg102_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg103_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg104_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg105_1, (1024, ), (1, ))
    assert_size_stride(arg106_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg107_1, (256, 1024), (1024, 1))
    assert_size_stride(arg108_1, (256, 1024), (1024, 1))
    assert_size_stride(arg109_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg110_1, (1024, ), (1, ))
    assert_size_stride(arg111_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg112_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg113_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg114_1, (1024, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((40, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [x], Original ATen: [aten.view, aten.t, aten.addmm]
        extern_kernels.addmm(arg2_1, reinterpret_tensor(arg0_1, (40, 64), (64, 1), 0), reinterpret_tensor(arg1_1, (64, 1024), (1, 64), 0), alpha=1, beta=1, out=buf0)
        del arg0_1
        del arg1_1
        del arg2_1
        buf2 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten.expand, aten.view, aten.cat, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_add_cat_expand_mean_mul_pow_rsqrt_view_0.run(arg3_1, buf0, arg6_1, buf2, 50, 1024, stream=stream0)
        del arg6_1
        buf3 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states, query_states], Original ATen: [aten.expand, aten.view, aten.cat, aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg7_1, (1024, 2048), (1, 1024), 0), out=buf3)
        del arg7_1
        buf4 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg8_1, (1024, 256), (1, 1024), 0), out=buf4)
        del arg8_1
        buf5 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf2, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg9_1, (1024, 256), (1, 1024), 0), out=buf5)
        del arg9_1
        buf6 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf3, arg4_1, arg5_1, buf6, 102400, stream=stream0)
        del buf3
        buf7 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf4, arg4_1, arg5_1, buf7, 12800, stream=stream0)
        buf8 = reinterpret_tensor(buf4, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf4  # reuse
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf5, buf8, 12800, stream=stream0)
        del buf5
        # Topologically Sorted Source Nodes: [query_states, view, query_states_1, q, chunk, key_states, view_1, key_states_1, k, chunk_1, position_ids, cos, mul_2, neg, cat_1, sin, mul_3, q_embed, query_states_2, query_states_3, mul_4, neg_1, cat_2, mul_5, k_embed, key_states_2, key_states_3, value_states, view_2, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf9 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf6, buf7, buf8, scale=0.08838834764831843)
        del buf7
        del buf8
        buf10 = buf9[0]
        assert_size_stride(buf10, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf10, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf9
        buf15 = reinterpret_tensor(buf6, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf6  # reuse
        # Topologically Sorted Source Nodes: [transpose_3, attn_output_1], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf10, buf15, 102400, stream=stream0)
        del buf10
        buf16 = reinterpret_tensor(buf2, (50, 1024), (1024, 1), 0); del buf2  # reuse
        # Topologically Sorted Source Nodes: [transpose_3, attn_output_1, attn_output_2, attn_output_3], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf15, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg10_1, (2048, 1024), (1, 2048), 0), out=buf16)
        del arg10_1
        del buf15
        buf18 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, to_6, pow_2, variance_1, add_4, rsqrt_1, mul_6, hidden_1, hidden_states_2], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_5.run(arg3_1, buf0, buf16, arg11_1, buf18, 50, 1024, stream=stream0)
        del arg11_1
        buf19 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_5], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf18, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg12_1, (1024, 4096), (1, 1024), 0), out=buf19)
        del arg12_1
        buf20 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_6], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf18, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg13_1, (1024, 4096), (1, 1024), 0), out=buf20)
        del arg13_1
        buf21 = reinterpret_tensor(buf19, (10, 5, 4096), (20480, 4096, 1), 0); del buf19  # reuse
        # Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf21, buf20, 204800, stream=stream0)
        del buf20
        buf22 = reinterpret_tensor(buf18, (50, 1024), (1024, 1), 0); del buf18  # reuse
        # Topologically Sorted Source Nodes: [linear_5, silu, linear_6, mul_8, hidden_states_3], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf21, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg14_1, (4096, 1024), (1, 4096), 0), out=buf22)
        del arg14_1
        buf24 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_7.run(arg3_1, buf0, buf16, buf22, arg15_1, buf24, 50, 1024, stream=stream0)
        del arg15_1
        buf25 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_4], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf24, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg16_1, (1024, 2048), (1, 1024), 0), out=buf25)
        del arg16_1
        buf26 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_4], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf24, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg17_1, (1024, 256), (1, 1024), 0), out=buf26)
        del arg17_1
        buf27 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_3], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf24, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg18_1, (1024, 256), (1, 1024), 0), out=buf27)
        del arg18_1
        buf28 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf25, arg4_1, arg5_1, buf28, 102400, stream=stream0)
        del buf25
        buf29 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf26, arg4_1, arg5_1, buf29, 12800, stream=stream0)
        buf30 = reinterpret_tensor(buf26, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf26  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf27, buf30, 12800, stream=stream0)
        del buf27
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_4, view_3, query_states_5, q_1, chunk_2, key_states_4, view_4, key_states_5, k_1, chunk_3, mul_11, neg_2, cat_3, mul_12, q_embed_1, query_states_6, query_states_7, mul_13, neg_3, cat_4, mul_14, k_embed_1, key_states_6, key_states_7, value_states_3, view_5, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf31 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf28, buf29, buf30, scale=0.08838834764831843)
        del buf29
        del buf30
        buf32 = buf31[0]
        assert_size_stride(buf32, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf32, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf31
        buf37 = reinterpret_tensor(buf28, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf28  # reuse
        # Topologically Sorted Source Nodes: [transpose_7, attn_output_5], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf32, buf37, 102400, stream=stream0)
        del buf32
        buf38 = reinterpret_tensor(buf24, (50, 1024), (1024, 1), 0); del buf24  # reuse
        # Topologically Sorted Source Nodes: [transpose_7, attn_output_5, attn_output_6, attn_output_7], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf37, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg19_1, (2048, 1024), (1, 2048), 0), out=buf38)
        del arg19_1
        del buf37
        buf39 = reinterpret_tensor(buf16, (10, 5, 1024), (5120, 1024, 1), 0); del buf16  # reuse
        buf41 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [special_tokens, x, x_1, x_2, attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, attn_output_7, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7], Original ATen: [aten.expand, aten.view, aten.cat, aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_cat_expand_mean_mul_pow_rsqrt_view_8.run(buf39, arg3_1, buf0, buf22, buf38, arg20_1, buf41, 50, 1024, stream=stream0)
        del arg20_1
        del arg3_1
        del buf0
        del buf22
        buf42 = reinterpret_tensor(buf21, (50, 4096), (4096, 1), 0); del buf21  # reuse
        # Topologically Sorted Source Nodes: [to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7, linear_12], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf41, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg21_1, (1024, 4096), (1, 1024), 0), out=buf42)
        del arg21_1
        buf43 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_13], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf41, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg22_1, (1024, 4096), (1, 1024), 0), out=buf43)
        del arg22_1
        buf44 = reinterpret_tensor(buf42, (10, 5, 4096), (20480, 4096, 1), 0); del buf42  # reuse
        # Topologically Sorted Source Nodes: [linear_12, silu_1, linear_13, mul_17], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf44, buf43, 204800, stream=stream0)
        del buf43
        buf45 = reinterpret_tensor(buf41, (50, 1024), (1024, 1), 0); del buf41  # reuse
        # Topologically Sorted Source Nodes: [linear_12, silu_1, linear_13, mul_17, hidden_states_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf44, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg23_1, (4096, 1024), (1, 4096), 0), out=buf45)
        del arg23_1
        buf47 = reinterpret_tensor(buf38, (10, 5, 1024), (5120, 1024, 1), 0); del buf38  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf39, buf45, arg24_1, buf47, 50, 1024, stream=stream0)
        del arg24_1
        buf48 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10, query_states_8], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg25_1, (1024, 2048), (1, 1024), 0), out=buf48)
        del arg25_1
        buf49 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_8], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg26_1, (1024, 256), (1, 1024), 0), out=buf49)
        del arg26_1
        buf50 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_6], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf47, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg27_1, (1024, 256), (1, 1024), 0), out=buf50)
        del arg27_1
        buf51 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf48, arg4_1, arg5_1, buf51, 102400, stream=stream0)
        del buf48
        buf52 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf49, arg4_1, arg5_1, buf52, 12800, stream=stream0)
        buf53 = reinterpret_tensor(buf49, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf49  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf50, buf53, 12800, stream=stream0)
        del buf50
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_8, view_6, query_states_9, q_2, chunk_4, key_states_8, view_7, key_states_9, k_2, chunk_5, mul_20, neg_4, cat_5, mul_21, q_embed_2, query_states_10, query_states_11, mul_22, neg_5, cat_6, mul_23, k_embed_2, key_states_10, key_states_11, value_states_6, view_8, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf54 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf51, buf52, buf53, scale=0.08838834764831843)
        del buf52
        del buf53
        buf55 = buf54[0]
        assert_size_stride(buf55, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf55, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf54
        buf60 = reinterpret_tensor(buf51, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf51  # reuse
        # Topologically Sorted Source Nodes: [transpose_11, attn_output_9], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf55, buf60, 102400, stream=stream0)
        del buf55
        buf61 = reinterpret_tensor(buf47, (50, 1024), (1024, 1), 0); del buf47  # reuse
        # Topologically Sorted Source Nodes: [transpose_11, attn_output_9, attn_output_10, attn_output_11], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf60, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg28_1, (2048, 1024), (1, 2048), 0), out=buf61)
        del arg28_1
        del buf60
        buf63 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, to_22, pow_6, variance_5, add_16, rsqrt_5, mul_24, hidden_5, hidden_states_12], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf39, buf45, buf61, arg29_1, buf63, 50, 1024, stream=stream0)
        del arg29_1
        buf64 = reinterpret_tensor(buf44, (50, 4096), (4096, 1), 0); del buf44  # reuse
        # Topologically Sorted Source Nodes: [linear_19], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf63, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg30_1, (1024, 4096), (1, 1024), 0), out=buf64)
        del arg30_1
        buf65 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf63, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg31_1, (1024, 4096), (1, 1024), 0), out=buf65)
        del arg31_1
        del buf63
        buf66 = reinterpret_tensor(buf64, (10, 5, 4096), (20480, 4096, 1), 0); del buf64  # reuse
        # Topologically Sorted Source Nodes: [linear_19, silu_2, linear_20, mul_26], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf66, buf65, 204800, stream=stream0)
        del buf65
        buf67 = empty_strided_cuda((50, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_19, silu_2, linear_20, mul_26, hidden_states_13], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf66, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg32_1, (4096, 1024), (1, 4096), 0), out=buf67)
        del arg32_1
        buf69 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, to_24, pow_7, variance_6, add_18, rsqrt_6, mul_27, hidden_6, hidden_states_15], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf39, buf45, buf61, buf67, arg33_1, buf69, 50, 1024, stream=stream0)
        del arg33_1
        buf70 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf69, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg34_1, (1024, 2048), (1, 1024), 0), out=buf70)
        del arg34_1
        buf71 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf69, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg35_1, (1024, 256), (1, 1024), 0), out=buf71)
        del arg35_1
        buf72 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_9], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf69, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg36_1, (1024, 256), (1, 1024), 0), out=buf72)
        del arg36_1
        buf73 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf70, arg4_1, arg5_1, buf73, 102400, stream=stream0)
        del buf70
        buf74 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf71, arg4_1, arg5_1, buf74, 12800, stream=stream0)
        buf75 = reinterpret_tensor(buf71, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf71  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf72, buf75, 12800, stream=stream0)
        del buf72
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_12, view_9, query_states_13, q_3, chunk_6, key_states_12, view_10, key_states_13, k_3, chunk_7, mul_29, neg_6, cat_7, mul_30, q_embed_3, query_states_14, query_states_15, mul_31, neg_7, cat_8, mul_32, k_embed_3, key_states_14, key_states_15, value_states_9, view_11, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf76 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf73, buf74, buf75, scale=0.08838834764831843)
        del buf74
        del buf75
        buf77 = buf76[0]
        assert_size_stride(buf77, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf77, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf76
        buf82 = reinterpret_tensor(buf73, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf73  # reuse
        # Topologically Sorted Source Nodes: [transpose_15, attn_output_13], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf77, buf82, 102400, stream=stream0)
        del buf77
        buf83 = reinterpret_tensor(buf69, (50, 1024), (1024, 1), 0); del buf69  # reuse
        # Topologically Sorted Source Nodes: [transpose_15, attn_output_13, attn_output_14, attn_output_15], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf82, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg37_1, (2048, 1024), (1, 2048), 0), out=buf83)
        del arg37_1
        del buf82
        buf84 = buf39; del buf39  # reuse
        buf86 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_8, hidden_states_9, attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, attn_output_15, hidden_states_16, to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf84, buf45, buf61, buf67, buf83, arg38_1, buf86, 50, 1024, stream=stream0)
        del arg38_1
        del buf45
        del buf61
        del buf67
        buf87 = reinterpret_tensor(buf66, (50, 4096), (4096, 1), 0); del buf66  # reuse
        # Topologically Sorted Source Nodes: [to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17, linear_26], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf86, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg39_1, (1024, 4096), (1, 1024), 0), out=buf87)
        del arg39_1
        buf88 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_27], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf86, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg40_1, (1024, 4096), (1, 1024), 0), out=buf88)
        del arg40_1
        buf89 = reinterpret_tensor(buf87, (10, 5, 4096), (20480, 4096, 1), 0); del buf87  # reuse
        # Topologically Sorted Source Nodes: [linear_26, silu_3, linear_27, mul_35], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf89, buf88, 204800, stream=stream0)
        del buf88
        buf90 = reinterpret_tensor(buf86, (50, 1024), (1024, 1), 0); del buf86  # reuse
        # Topologically Sorted Source Nodes: [linear_26, silu_3, linear_27, mul_35, hidden_states_18], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf89, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg41_1, (4096, 1024), (1, 4096), 0), out=buf90)
        del arg41_1
        buf92 = reinterpret_tensor(buf83, (10, 5, 1024), (5120, 1024, 1), 0); del buf83  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, to_32, pow_9, variance_8, add_24, rsqrt_8, mul_36, hidden_8, hidden_states_20], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf84, buf90, arg42_1, buf92, 50, 1024, stream=stream0)
        del arg42_1
        buf93 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, to_32, pow_9, variance_8, add_24, rsqrt_8, mul_36, hidden_8, hidden_states_20, query_states_16], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg43_1, (1024, 2048), (1, 1024), 0), out=buf93)
        del arg43_1
        buf94 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_16], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg44_1, (1024, 256), (1, 1024), 0), out=buf94)
        del arg44_1
        buf95 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf92, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg45_1, (1024, 256), (1, 1024), 0), out=buf95)
        del arg45_1
        buf96 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf93, arg4_1, arg5_1, buf96, 102400, stream=stream0)
        del buf93
        buf97 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf94, arg4_1, arg5_1, buf97, 12800, stream=stream0)
        buf98 = reinterpret_tensor(buf94, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf94  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf95, buf98, 12800, stream=stream0)
        del buf95
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_16, view_12, query_states_17, q_4, chunk_8, key_states_16, view_13, key_states_17, k_4, chunk_9, mul_38, neg_8, cat_9, mul_39, q_embed_4, query_states_18, query_states_19, mul_40, neg_9, cat_10, mul_41, k_embed_4, key_states_18, key_states_19, value_states_12, view_14, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf99 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf96, buf97, buf98, scale=0.08838834764831843)
        del buf97
        del buf98
        buf100 = buf99[0]
        assert_size_stride(buf100, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf100, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf99
        buf105 = reinterpret_tensor(buf96, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf96  # reuse
        # Topologically Sorted Source Nodes: [transpose_19, attn_output_17], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf100, buf105, 102400, stream=stream0)
        del buf100
        buf106 = reinterpret_tensor(buf92, (50, 1024), (1024, 1), 0); del buf92  # reuse
        # Topologically Sorted Source Nodes: [transpose_19, attn_output_17, attn_output_18, attn_output_19], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf105, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg46_1, (2048, 1024), (1, 2048), 0), out=buf106)
        del arg46_1
        del buf105
        buf108 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, attn_output_19, hidden_states_21, to_38, pow_10, variance_9, add_28, rsqrt_9, mul_42, hidden_9, hidden_states_22], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf84, buf90, buf106, arg47_1, buf108, 50, 1024, stream=stream0)
        del arg47_1
        buf109 = reinterpret_tensor(buf89, (50, 4096), (4096, 1), 0); del buf89  # reuse
        # Topologically Sorted Source Nodes: [linear_33], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf108, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg48_1, (1024, 4096), (1, 1024), 0), out=buf109)
        del arg48_1
        buf110 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_34], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf108, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg49_1, (1024, 4096), (1, 1024), 0), out=buf110)
        del arg49_1
        del buf108
        buf111 = reinterpret_tensor(buf109, (10, 5, 4096), (20480, 4096, 1), 0); del buf109  # reuse
        # Topologically Sorted Source Nodes: [linear_33, silu_4, linear_34, mul_44], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf111, buf110, 204800, stream=stream0)
        del buf110
        buf112 = empty_strided_cuda((50, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_33, silu_4, linear_34, mul_44, hidden_states_23], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf111, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg50_1, (4096, 1024), (1, 4096), 0), out=buf112)
        del arg50_1
        buf114 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, attn_output_19, hidden_states_21, hidden_states_23, hidden_states_24, to_40, pow_11, variance_10, add_30, rsqrt_10, mul_45, hidden_10, hidden_states_25], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf84, buf90, buf106, buf112, arg51_1, buf114, 50, 1024, stream=stream0)
        del arg51_1
        buf115 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf114, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg52_1, (1024, 2048), (1, 1024), 0), out=buf115)
        del arg52_1
        buf116 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf114, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg53_1, (1024, 256), (1, 1024), 0), out=buf116)
        del arg53_1
        buf117 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_15], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf114, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg54_1, (1024, 256), (1, 1024), 0), out=buf117)
        del arg54_1
        buf118 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf115, arg4_1, arg5_1, buf118, 102400, stream=stream0)
        del buf115
        buf119 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf116, arg4_1, arg5_1, buf119, 12800, stream=stream0)
        buf120 = reinterpret_tensor(buf116, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf116  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf117, buf120, 12800, stream=stream0)
        del buf117
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_20, view_15, query_states_21, q_5, chunk_10, key_states_20, view_16, key_states_21, k_5, chunk_11, mul_47, neg_10, cat_11, mul_48, q_embed_5, query_states_22, query_states_23, mul_49, neg_11, cat_12, mul_50, k_embed_5, key_states_22, key_states_23, value_states_15, view_17, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf121 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf118, buf119, buf120, scale=0.08838834764831843)
        del buf119
        del buf120
        buf122 = buf121[0]
        assert_size_stride(buf122, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf122, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf121
        buf127 = reinterpret_tensor(buf118, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf118  # reuse
        # Topologically Sorted Source Nodes: [transpose_23, attn_output_21], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf122, buf127, 102400, stream=stream0)
        del buf122
        buf128 = reinterpret_tensor(buf114, (50, 1024), (1024, 1), 0); del buf114  # reuse
        # Topologically Sorted Source Nodes: [transpose_23, attn_output_21, attn_output_22, attn_output_23], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf127, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg55_1, (2048, 1024), (1, 2048), 0), out=buf128)
        del arg55_1
        del buf127
        buf129 = buf84; del buf84  # reuse
        buf131 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_18, hidden_states_19, attn_output_19, hidden_states_21, hidden_states_23, hidden_states_24, attn_output_23, hidden_states_26, to_46, pow_12, variance_11, add_34, rsqrt_11, mul_51, hidden_11, hidden_states_27], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf129, buf90, buf106, buf112, buf128, arg56_1, buf131, 50, 1024, stream=stream0)
        del arg56_1
        del buf106
        del buf112
        del buf128
        buf132 = reinterpret_tensor(buf111, (50, 4096), (4096, 1), 0); del buf111  # reuse
        # Topologically Sorted Source Nodes: [to_46, pow_12, variance_11, add_34, rsqrt_11, mul_51, hidden_11, hidden_states_27, linear_40], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf131, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg57_1, (1024, 4096), (1, 1024), 0), out=buf132)
        del arg57_1
        buf133 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_41], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf131, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg58_1, (1024, 4096), (1, 1024), 0), out=buf133)
        del arg58_1
        buf134 = reinterpret_tensor(buf132, (10, 5, 4096), (20480, 4096, 1), 0); del buf132  # reuse
        # Topologically Sorted Source Nodes: [linear_40, silu_5, linear_41, mul_53], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf134, buf133, 204800, stream=stream0)
        del buf133
        buf135 = reinterpret_tensor(buf131, (50, 1024), (1024, 1), 0); del buf131  # reuse
        # Topologically Sorted Source Nodes: [linear_40, silu_5, linear_41, mul_53, hidden_states_28], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf134, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg59_1, (4096, 1024), (1, 4096), 0), out=buf135)
        del arg59_1
        buf137 = reinterpret_tensor(buf90, (10, 5, 1024), (5120, 1024, 1), 0); del buf90  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, to_48, pow_13, variance_12, add_36, rsqrt_12, mul_54, hidden_12, hidden_states_30], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf129, buf135, arg60_1, buf137, 50, 1024, stream=stream0)
        del arg60_1
        buf138 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, to_48, pow_13, variance_12, add_36, rsqrt_12, mul_54, hidden_12, hidden_states_30, query_states_24], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg61_1, (1024, 2048), (1, 1024), 0), out=buf138)
        del arg61_1
        buf139 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg62_1, (1024, 256), (1, 1024), 0), out=buf139)
        del arg62_1
        buf140 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_18], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf137, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg63_1, (1024, 256), (1, 1024), 0), out=buf140)
        del arg63_1
        buf141 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf138, arg4_1, arg5_1, buf141, 102400, stream=stream0)
        del buf138
        buf142 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf139, arg4_1, arg5_1, buf142, 12800, stream=stream0)
        buf143 = reinterpret_tensor(buf139, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf139  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf140, buf143, 12800, stream=stream0)
        del buf140
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_24, view_18, query_states_25, q_6, chunk_12, key_states_24, view_19, key_states_25, k_6, chunk_13, mul_56, neg_12, cat_13, mul_57, q_embed_6, query_states_26, query_states_27, mul_58, neg_13, cat_14, mul_59, k_embed_6, key_states_26, key_states_27, value_states_18, view_20, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf144 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf141, buf142, buf143, scale=0.08838834764831843)
        del buf142
        del buf143
        buf145 = buf144[0]
        assert_size_stride(buf145, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf145, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf144
        buf150 = reinterpret_tensor(buf141, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf141  # reuse
        # Topologically Sorted Source Nodes: [transpose_27, attn_output_25], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf145, buf150, 102400, stream=stream0)
        del buf145
        buf151 = reinterpret_tensor(buf137, (50, 1024), (1024, 1), 0); del buf137  # reuse
        # Topologically Sorted Source Nodes: [transpose_27, attn_output_25, attn_output_26, attn_output_27], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf150, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg64_1, (2048, 1024), (1, 2048), 0), out=buf151)
        del arg64_1
        del buf150
        buf153 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, attn_output_27, hidden_states_31, to_54, pow_14, variance_13, add_40, rsqrt_13, mul_60, hidden_13, hidden_states_32], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf129, buf135, buf151, arg65_1, buf153, 50, 1024, stream=stream0)
        del arg65_1
        buf154 = reinterpret_tensor(buf134, (50, 4096), (4096, 1), 0); del buf134  # reuse
        # Topologically Sorted Source Nodes: [linear_47], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf153, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg66_1, (1024, 4096), (1, 1024), 0), out=buf154)
        del arg66_1
        buf155 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_48], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf153, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg67_1, (1024, 4096), (1, 1024), 0), out=buf155)
        del arg67_1
        del buf153
        buf156 = reinterpret_tensor(buf154, (10, 5, 4096), (20480, 4096, 1), 0); del buf154  # reuse
        # Topologically Sorted Source Nodes: [linear_47, silu_6, linear_48, mul_62], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf156, buf155, 204800, stream=stream0)
        del buf155
        buf157 = empty_strided_cuda((50, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_47, silu_6, linear_48, mul_62, hidden_states_33], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf156, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg68_1, (4096, 1024), (1, 4096), 0), out=buf157)
        del arg68_1
        buf159 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, attn_output_27, hidden_states_31, hidden_states_33, hidden_states_34, to_56, pow_15, variance_14, add_42, rsqrt_14, mul_63, hidden_14, hidden_states_35], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf129, buf135, buf151, buf157, arg69_1, buf159, 50, 1024, stream=stream0)
        del arg69_1
        buf160 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_28], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf159, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg70_1, (1024, 2048), (1, 1024), 0), out=buf160)
        del arg70_1
        buf161 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_28], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf159, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg71_1, (1024, 256), (1, 1024), 0), out=buf161)
        del arg71_1
        buf162 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_21], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf159, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg72_1, (1024, 256), (1, 1024), 0), out=buf162)
        del arg72_1
        buf163 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf160, arg4_1, arg5_1, buf163, 102400, stream=stream0)
        del buf160
        buf164 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf161, arg4_1, arg5_1, buf164, 12800, stream=stream0)
        buf165 = reinterpret_tensor(buf161, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf161  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf162, buf165, 12800, stream=stream0)
        del buf162
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_28, view_21, query_states_29, q_7, chunk_14, key_states_28, view_22, key_states_29, k_7, chunk_15, mul_65, neg_14, cat_15, mul_66, q_embed_7, query_states_30, query_states_31, mul_67, neg_15, cat_16, mul_68, k_embed_7, key_states_30, key_states_31, value_states_21, view_23, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf166 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf163, buf164, buf165, scale=0.08838834764831843)
        del buf164
        del buf165
        buf167 = buf166[0]
        assert_size_stride(buf167, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf167, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf166
        buf172 = reinterpret_tensor(buf163, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf163  # reuse
        # Topologically Sorted Source Nodes: [transpose_31, attn_output_29], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf167, buf172, 102400, stream=stream0)
        del buf167
        buf173 = reinterpret_tensor(buf159, (50, 1024), (1024, 1), 0); del buf159  # reuse
        # Topologically Sorted Source Nodes: [transpose_31, attn_output_29, attn_output_30, attn_output_31], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf172, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg73_1, (2048, 1024), (1, 2048), 0), out=buf173)
        del arg73_1
        del buf172
        buf174 = buf129; del buf129  # reuse
        buf176 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_28, hidden_states_29, attn_output_27, hidden_states_31, hidden_states_33, hidden_states_34, attn_output_31, hidden_states_36, to_62, pow_16, variance_15, add_46, rsqrt_15, mul_69, hidden_15, hidden_states_37], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf174, buf135, buf151, buf157, buf173, arg74_1, buf176, 50, 1024, stream=stream0)
        del arg74_1
        del buf135
        del buf151
        del buf157
        buf177 = reinterpret_tensor(buf156, (50, 4096), (4096, 1), 0); del buf156  # reuse
        # Topologically Sorted Source Nodes: [to_62, pow_16, variance_15, add_46, rsqrt_15, mul_69, hidden_15, hidden_states_37, linear_54], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf176, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg75_1, (1024, 4096), (1, 1024), 0), out=buf177)
        del arg75_1
        buf178 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_55], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf176, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg76_1, (1024, 4096), (1, 1024), 0), out=buf178)
        del arg76_1
        buf179 = reinterpret_tensor(buf177, (10, 5, 4096), (20480, 4096, 1), 0); del buf177  # reuse
        # Topologically Sorted Source Nodes: [linear_54, silu_7, linear_55, mul_71], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf179, buf178, 204800, stream=stream0)
        del buf178
        buf180 = reinterpret_tensor(buf176, (50, 1024), (1024, 1), 0); del buf176  # reuse
        # Topologically Sorted Source Nodes: [linear_54, silu_7, linear_55, mul_71, hidden_states_38], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf179, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg77_1, (4096, 1024), (1, 4096), 0), out=buf180)
        del arg77_1
        buf182 = reinterpret_tensor(buf173, (10, 5, 1024), (5120, 1024, 1), 0); del buf173  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, to_64, pow_17, variance_16, add_48, rsqrt_16, mul_72, hidden_16, hidden_states_40], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf174, buf180, arg78_1, buf182, 50, 1024, stream=stream0)
        del arg78_1
        buf183 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, to_64, pow_17, variance_16, add_48, rsqrt_16, mul_72, hidden_16, hidden_states_40, query_states_32], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg79_1, (1024, 2048), (1, 1024), 0), out=buf183)
        del arg79_1
        buf184 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_32], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg80_1, (1024, 256), (1, 1024), 0), out=buf184)
        del arg80_1
        buf185 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf182, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg81_1, (1024, 256), (1, 1024), 0), out=buf185)
        del arg81_1
        buf186 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf183, arg4_1, arg5_1, buf186, 102400, stream=stream0)
        del buf183
        buf187 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf184, arg4_1, arg5_1, buf187, 12800, stream=stream0)
        buf188 = reinterpret_tensor(buf184, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf184  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf185, buf188, 12800, stream=stream0)
        del buf185
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_32, view_24, query_states_33, q_8, chunk_16, key_states_32, view_25, key_states_33, k_8, chunk_17, mul_74, neg_16, cat_17, mul_75, q_embed_8, query_states_34, query_states_35, mul_76, neg_17, cat_18, mul_77, k_embed_8, key_states_34, key_states_35, value_states_24, view_26, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf189 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf186, buf187, buf188, scale=0.08838834764831843)
        del buf187
        del buf188
        buf190 = buf189[0]
        assert_size_stride(buf190, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf190, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf189
        buf195 = reinterpret_tensor(buf186, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf186  # reuse
        # Topologically Sorted Source Nodes: [transpose_35, attn_output_33], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf190, buf195, 102400, stream=stream0)
        del buf190
        buf196 = reinterpret_tensor(buf182, (50, 1024), (1024, 1), 0); del buf182  # reuse
        # Topologically Sorted Source Nodes: [transpose_35, attn_output_33, attn_output_34, attn_output_35], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf195, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg82_1, (2048, 1024), (1, 2048), 0), out=buf196)
        del arg82_1
        del buf195
        buf198 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, attn_output_35, hidden_states_41, to_70, pow_18, variance_17, add_52, rsqrt_17, mul_78, hidden_17, hidden_states_42], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf174, buf180, buf196, arg83_1, buf198, 50, 1024, stream=stream0)
        del arg83_1
        buf199 = reinterpret_tensor(buf179, (50, 4096), (4096, 1), 0); del buf179  # reuse
        # Topologically Sorted Source Nodes: [linear_61], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf198, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg84_1, (1024, 4096), (1, 1024), 0), out=buf199)
        del arg84_1
        buf200 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_62], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf198, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg85_1, (1024, 4096), (1, 1024), 0), out=buf200)
        del arg85_1
        del buf198
        buf201 = reinterpret_tensor(buf199, (10, 5, 4096), (20480, 4096, 1), 0); del buf199  # reuse
        # Topologically Sorted Source Nodes: [linear_61, silu_8, linear_62, mul_80], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf201, buf200, 204800, stream=stream0)
        del buf200
        buf202 = empty_strided_cuda((50, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_61, silu_8, linear_62, mul_80, hidden_states_43], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf201, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg86_1, (4096, 1024), (1, 4096), 0), out=buf202)
        del arg86_1
        buf204 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, attn_output_35, hidden_states_41, hidden_states_43, hidden_states_44, to_72, pow_19, variance_18, add_54, rsqrt_18, mul_81, hidden_18, hidden_states_45], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf174, buf180, buf196, buf202, arg87_1, buf204, 50, 1024, stream=stream0)
        del arg87_1
        buf205 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_36], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf204, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg88_1, (1024, 2048), (1, 1024), 0), out=buf205)
        del arg88_1
        buf206 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_36], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf204, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg89_1, (1024, 256), (1, 1024), 0), out=buf206)
        del arg89_1
        buf207 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_27], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf204, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg90_1, (1024, 256), (1, 1024), 0), out=buf207)
        del arg90_1
        buf208 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf205, arg4_1, arg5_1, buf208, 102400, stream=stream0)
        del buf205
        buf209 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf206, arg4_1, arg5_1, buf209, 12800, stream=stream0)
        buf210 = reinterpret_tensor(buf206, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf206  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf207, buf210, 12800, stream=stream0)
        del buf207
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_36, view_27, query_states_37, q_9, chunk_18, key_states_36, view_28, key_states_37, k_9, chunk_19, mul_83, neg_18, cat_19, mul_84, q_embed_9, query_states_38, query_states_39, mul_85, neg_19, cat_20, mul_86, k_embed_9, key_states_38, key_states_39, value_states_27, view_29, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf211 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf208, buf209, buf210, scale=0.08838834764831843)
        del buf209
        del buf210
        buf212 = buf211[0]
        assert_size_stride(buf212, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf212, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf211
        buf217 = reinterpret_tensor(buf208, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf208  # reuse
        # Topologically Sorted Source Nodes: [transpose_39, attn_output_37], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf212, buf217, 102400, stream=stream0)
        del buf212
        buf218 = reinterpret_tensor(buf204, (50, 1024), (1024, 1), 0); del buf204  # reuse
        # Topologically Sorted Source Nodes: [transpose_39, attn_output_37, attn_output_38, attn_output_39], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf217, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg91_1, (2048, 1024), (1, 2048), 0), out=buf218)
        del arg91_1
        del buf217
        buf219 = buf174; del buf174  # reuse
        buf221 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_38, hidden_states_39, attn_output_35, hidden_states_41, hidden_states_43, hidden_states_44, attn_output_39, hidden_states_46, to_78, pow_20, variance_19, add_58, rsqrt_19, mul_87, hidden_19, hidden_states_47], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf219, buf180, buf196, buf202, buf218, arg92_1, buf221, 50, 1024, stream=stream0)
        del arg92_1
        del buf180
        del buf196
        del buf202
        buf222 = reinterpret_tensor(buf201, (50, 4096), (4096, 1), 0); del buf201  # reuse
        # Topologically Sorted Source Nodes: [to_78, pow_20, variance_19, add_58, rsqrt_19, mul_87, hidden_19, hidden_states_47, linear_68], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf221, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg93_1, (1024, 4096), (1, 1024), 0), out=buf222)
        del arg93_1
        buf223 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_69], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf221, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg94_1, (1024, 4096), (1, 1024), 0), out=buf223)
        del arg94_1
        buf224 = reinterpret_tensor(buf222, (10, 5, 4096), (20480, 4096, 1), 0); del buf222  # reuse
        # Topologically Sorted Source Nodes: [linear_68, silu_9, linear_69, mul_89], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf224, buf223, 204800, stream=stream0)
        del buf223
        buf225 = reinterpret_tensor(buf221, (50, 1024), (1024, 1), 0); del buf221  # reuse
        # Topologically Sorted Source Nodes: [linear_68, silu_9, linear_69, mul_89, hidden_states_48], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf224, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg95_1, (4096, 1024), (1, 4096), 0), out=buf225)
        del arg95_1
        buf227 = reinterpret_tensor(buf218, (10, 5, 1024), (5120, 1024, 1), 0); del buf218  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, to_80, pow_21, variance_20, add_60, rsqrt_20, mul_90, hidden_20, hidden_states_50], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_9.run(buf219, buf225, arg96_1, buf227, 50, 1024, stream=stream0)
        del arg96_1
        buf228 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, to_80, pow_21, variance_20, add_60, rsqrt_20, mul_90, hidden_20, hidden_states_50, query_states_40], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg97_1, (1024, 2048), (1, 1024), 0), out=buf228)
        del arg97_1
        buf229 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_40], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg98_1, (1024, 256), (1, 1024), 0), out=buf229)
        del arg98_1
        buf230 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_30], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf227, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg99_1, (1024, 256), (1, 1024), 0), out=buf230)
        del arg99_1
        buf231 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf228, arg4_1, arg5_1, buf231, 102400, stream=stream0)
        del buf228
        buf232 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf229, arg4_1, arg5_1, buf232, 12800, stream=stream0)
        buf233 = reinterpret_tensor(buf229, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf229  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf230, buf233, 12800, stream=stream0)
        del buf230
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_40, view_30, query_states_41, q_10, chunk_20, key_states_40, view_31, key_states_41, k_10, chunk_21, mul_92, neg_20, cat_21, mul_93, q_embed_10, query_states_42, query_states_43, mul_94, neg_21, cat_22, mul_95, k_embed_10, key_states_42, key_states_43, value_states_30, view_32, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf234 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf231, buf232, buf233, scale=0.08838834764831843)
        del buf232
        del buf233
        buf235 = buf234[0]
        assert_size_stride(buf235, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf235, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf234
        buf240 = reinterpret_tensor(buf231, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf231  # reuse
        # Topologically Sorted Source Nodes: [transpose_43, attn_output_41], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf235, buf240, 102400, stream=stream0)
        del buf235
        buf241 = reinterpret_tensor(buf227, (50, 1024), (1024, 1), 0); del buf227  # reuse
        # Topologically Sorted Source Nodes: [transpose_43, attn_output_41, attn_output_42, attn_output_43], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf240, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg100_1, (2048, 1024), (1, 2048), 0), out=buf241)
        del arg100_1
        del buf240
        buf243 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, attn_output_43, hidden_states_51, to_86, pow_22, variance_21, add_64, rsqrt_21, mul_96, hidden_21, hidden_states_52], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf219, buf225, buf241, arg101_1, buf243, 50, 1024, stream=stream0)
        del arg101_1
        buf244 = reinterpret_tensor(buf224, (50, 4096), (4096, 1), 0); del buf224  # reuse
        # Topologically Sorted Source Nodes: [linear_75], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf243, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg102_1, (1024, 4096), (1, 1024), 0), out=buf244)
        del arg102_1
        buf245 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_76], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf243, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg103_1, (1024, 4096), (1, 1024), 0), out=buf245)
        del arg103_1
        del buf243
        buf246 = reinterpret_tensor(buf244, (10, 5, 4096), (20480, 4096, 1), 0); del buf244  # reuse
        # Topologically Sorted Source Nodes: [linear_75, silu_10, linear_76, mul_98], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf246, buf245, 204800, stream=stream0)
        del buf245
        buf247 = empty_strided_cuda((50, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_75, silu_10, linear_76, mul_98, hidden_states_53], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf246, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg104_1, (4096, 1024), (1, 4096), 0), out=buf247)
        del arg104_1
        buf249 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, to_88, pow_23, variance_22, add_66, rsqrt_22, mul_99, hidden_22, hidden_states_55], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf219, buf225, buf241, buf247, arg105_1, buf249, 50, 1024, stream=stream0)
        del arg105_1
        buf250 = empty_strided_cuda((50, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_44], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf249, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg106_1, (1024, 2048), (1, 1024), 0), out=buf250)
        del arg106_1
        buf251 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_44], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf249, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg107_1, (1024, 256), (1, 1024), 0), out=buf251)
        del arg107_1
        buf252 = empty_strided_cuda((50, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_33], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf249, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg108_1, (1024, 256), (1, 1024), 0), out=buf252)
        del arg108_1
        buf253 = empty_strided_cuda((10, 16, 5, 128), (10240, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_1.run(buf250, arg4_1, arg5_1, buf253, 102400, stream=stream0)
        del buf250
        buf254 = empty_strided_cuda((10, 2, 5, 128), (1280, 640, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_2.run(buf251, arg4_1, arg5_1, buf254, 12800, stream=stream0)
        del arg4_1
        del arg5_1
        buf255 = reinterpret_tensor(buf251, (10, 2, 5, 128), (1280, 640, 128, 1), 0); del buf251  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_3.run(buf252, buf255, 12800, stream=stream0)
        del buf252
        # Topologically Sorted Source Nodes: [position_ids, cos, sin, query_states_44, view_33, query_states_45, q_11, chunk_22, key_states_44, view_34, key_states_45, k_11, chunk_23, mul_101, neg_22, cat_23, mul_102, q_embed_11, query_states_46, query_states_47, mul_103, neg_23, cat_24, mul_104, k_embed_11, key_states_46, key_states_47, value_states_33, view_35, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf256 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf253, buf254, buf255, scale=0.08838834764831843)
        del buf254
        del buf255
        buf257 = buf256[0]
        assert_size_stride(buf257, (10, 16, 5, 128), (10240, 640, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf257, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf256
        buf262 = reinterpret_tensor(buf253, (10, 5, 16, 128), (10240, 2048, 128, 1), 0); del buf253  # reuse
        # Topologically Sorted Source Nodes: [transpose_47, attn_output_45], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_4.run(buf257, buf262, 102400, stream=stream0)
        del buf257
        buf263 = reinterpret_tensor(buf249, (50, 1024), (1024, 1), 0); del buf249  # reuse
        # Topologically Sorted Source Nodes: [transpose_47, attn_output_45, attn_output_46, attn_output_47], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf262, (50, 2048), (2048, 1), 0), reinterpret_tensor(arg109_1, (2048, 1024), (1, 2048), 0), out=buf263)
        del arg109_1
        del buf262
        buf264 = buf219; del buf219  # reuse
        buf266 = empty_strided_cuda((10, 5, 1024), (5120, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_48, hidden_states_49, attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, attn_output_47, hidden_states_56, to_94, pow_24, variance_23, add_70, rsqrt_23, mul_105, hidden_23, hidden_states_57], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf264, buf225, buf241, buf247, buf263, arg110_1, buf266, 50, 1024, stream=stream0)
        del arg110_1
        del buf225
        del buf241
        del buf247
        del buf263
        buf267 = reinterpret_tensor(buf246, (50, 4096), (4096, 1), 0); del buf246  # reuse
        # Topologically Sorted Source Nodes: [to_94, pow_24, variance_23, add_70, rsqrt_23, mul_105, hidden_23, hidden_states_57, linear_82], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf266, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg111_1, (1024, 4096), (1, 1024), 0), out=buf267)
        del arg111_1
        buf268 = empty_strided_cuda((50, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_83], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf266, (50, 1024), (1024, 1), 0), reinterpret_tensor(arg112_1, (1024, 4096), (1, 1024), 0), out=buf268)
        del arg112_1
        buf269 = reinterpret_tensor(buf267, (10, 5, 4096), (20480, 4096, 1), 0); del buf267  # reuse
        # Topologically Sorted Source Nodes: [linear_82, silu_11, linear_83, mul_107], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_6.run(buf269, buf268, 204800, stream=stream0)
        del buf268
        buf270 = reinterpret_tensor(buf266, (50, 1024), (1024, 1), 0); del buf266  # reuse
        # Topologically Sorted Source Nodes: [linear_82, silu_11, linear_83, mul_107, hidden_states_58], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf269, (50, 4096), (4096, 1), 0), reinterpret_tensor(arg113_1, (4096, 1024), (1, 4096), 0), out=buf270)
        del arg113_1
        del buf269
        buf272 = buf264; del buf264  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_58, hidden_states_59, to_96, pow_25, variance_24, add_72, rsqrt_24, mul_108, hidden_24, hidden_states_60], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_13.run(buf272, buf270, arg114_1, 50, 1024, stream=stream0)
        del arg114_1
        del buf270
    return (buf272, )


# Module-level finalisation of the AsyncCompile instance created at the top of
# this file. NOTE(review): presumably this blocks until any asynchronously
# compiled kernels bound in this module's globals are ready — confirm against
# torch._inductor.async_compile.AsyncCompile.wait.
async_compile.wait(globals())
# Drop the module-level name; it is not referenced after this point in the
# visible code.
del async_compile

class Runner:
    """Hold the compiled graph partitions and run them from ``call``.

    In this module it is built with the single partition ``partition_0``;
    ``call`` only ever invokes ``self.partitions[0]``.
    """
    def __init__(self, partitions):
        # Sequence of partition callables; ``call`` indexes it with [0].
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition ``c`` with ``fn(c)``, pairing ``fns`` and partitions by position.

        ``zip`` stops at the shorter input, so if ``fns`` has fewer entries than
        ``self.partitions`` the unmatched trailing partitions are dropped from
        the new list. NOTE(review): callers presumably pass exactly one fn per
        partition — confirm.
        """
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        """Run the compiled graph on the inputs in ``args``.

        ``args`` must be a list of exactly 115 inputs (arg0_1 .. arg114_1);
        the tuple unpacking below raises ``ValueError`` for any other length.
        The list is cleared in place, so callers that still need their inputs
        afterwards must pass a copy.

        Returns a 1-tuple holding a ``reinterpret_tensor`` view of the
        partition output ``buf272`` with size (1, 10, 1024), stride
        (51200, 5120, 1) and offset 0. Assuming ``buf272`` is a contiguous
        (50, 1024) buffer (TODO confirm), this view selects every fifth row.
        """
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1 = args
        # Drop the caller's list references to the inputs — presumably so each
        # tensor can be freed as soon as the partition releases it.
        args.clear()
        # Partition 0 receives the inputs in this permuted order (arg2_1, arg0_1,
        # arg1_1, arg3_1, arg6_1..arg9_1, arg4_1, arg5_1, arg10_1..arg114_1);
        # presumably it must match the order partition_0 unpacks — TODO confirm.
        partition0_args = [arg2_1, arg0_1, arg1_1, arg3_1, arg6_1, arg7_1, arg8_1, arg9_1, arg4_1, arg5_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1]
        # Release the local names so, within this frame, partition0_args holds
        # the only references to the inputs.
        del arg2_1, arg0_1, arg1_1, arg3_1, arg6_1, arg7_1, arg8_1, arg9_1, arg4_1, arg5_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1
        # The partition returns a 1-tuple; unpacking enforces that.
        (buf272,) = self.partitions[0](partition0_args)
        del partition0_args
        # Reinterpret the output buffer's storage as (1, 10, 1024) with a row
        # stride of 5120 (= 5 * 1024), offset 0.
        return (reinterpret_tensor(buf272, (1, 10, 1024), (51200, 5120, 1), 0), )

# Wrap the single generated partition in a Runner and publish its bound
# methods at module level under the names ``call`` and ``recursively_apply_fns``.
runner = Runner(partitions=[partition_0])
call, recursively_apply_fns = runner.call, runner.recursively_apply_fns


def get_args():
    """Build the 115 bfloat16 ``cuda:0`` inputs (arg0_1 .. arg114_1) for ``call``.

    Each input is produced by ``rand_strided`` with an explicit size and
    stride. Inputs are created strictly in argument order, so the sequence of
    ``rand_strided`` calls is the same as listing them one by one.
    """
    from torch._dynamo.testing import rand_strided

    # (size, stride) for the six leading inputs, arg0_1 .. arg5_1.
    leading_specs = [
        ((1, 10, 4, 64), (2560, 256, 64, 1)),
        ((1024, 64), (64, 1)),
        ((1024, ), (1, )),
        ((1, 1, 1, 1024), (1024, 1024, 1024, 1)),
        ((32768, 128), (128, 1)),
        ((32768, 128), (128, 1)),
    ]
    # Nine-input group that repeats twelve times, covering arg6_1 .. arg113_1.
    repeated_group = [
        ((1024, ), (1, )),
        ((2048, 1024), (1024, 1)),
        ((256, 1024), (1024, 1)),
        ((256, 1024), (1024, 1)),
        ((1024, 2048), (2048, 1)),
        ((1024, ), (1, )),
        ((4096, 1024), (1024, 1)),
        ((4096, 1024), (1024, 1)),
        ((1024, 4096), (4096, 1)),
    ]
    # Single trailing input, arg114_1.
    trailing_specs = [((1024, ), (1, ))]

    all_specs = leading_specs + repeated_group * 12 + trailing_specs
    return [
        rand_strided(size, stride, device='cuda:0', dtype=torch.bfloat16)
        for size, stride in all_specs
    ]


def benchmark_compiled_module(args, times=10, repeat=10):
    """Time the module-level ``call`` on ``args`` with Inductor's ``print_performance``.

    ``call`` clears the list it is given, so every timed invocation receives
    a fresh shallow copy of ``args``. Returns whatever ``print_performance``
    returns.
    """
    from torch._inductor.utils import print_performance

    def run_once():
        # Shallow copy each time: ``call`` empties the list it receives.
        return call(list(args))

    return print_performance(run_once, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main

    # Build the inputs once and reuse them for every benchmark request.
    args = get_args()

    def _run_benchmark(times, repeat):
        # Adapter matching the (times, repeat) callback expected by compiled_module_main.
        return benchmark_compiled_module(args, times=times, repeat=repeat)

    compiled_module_main('None', _run_benchmark)
