# AOT ID: ['3_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Module-level shorthand for torch operator namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Size/stride/alignment assertion helpers and per-device empty_strided
# allocators, all bound once from torch._C._dynamo.guards.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Single AsyncCompile instance; every kernel definition below is created by
# calling its .triton(name, source, device_str=...) method.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
from torch._C import _cuda_getCurrentRawStream as get_raw_stream



# kernel path: /tmp/torchinductor_shingokuga/ly/cly7md7thmpshkq3dh4ixp35kisdgzvf2kec5norc7twrfgulusd.py
# Topologically Sorted Source Nodes: [to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add => add
#   hidden => convert_element_type_1
#   hidden_states => mul_1
#   mul => mul
#   pow_1 => pow_1
#   rsqrt => rsqrt
#   to => convert_element_type
#   variance => mean
# Graph fragment:
#   %arg5_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %buf0 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf0]
#   %arg4_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg4_1]
#   %convert_element_type : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg5_1, torch.float32), kwargs = {})
#   %pow_1 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {})
#   %mean : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
#   %add : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
#   %rsqrt : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
#   %mul : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg5_1, %rsqrt), kwargs = {})
#   %convert_element_type_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
#   %mul_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1, %arg4_1), kwargs = {})
#   return %buf0,%mul_1
# Reduction kernel, registered through async_compile.triton from the source
# string below. The string is handed to Triton verbatim, so it is kept
# byte-for-byte; the notes here describe what the kernel text does.
#   * xnumel and r0_numel are overwritten inside the kernel with the constants
#     1 and 2048, so the values passed for those arguments are not used.
#   * First loop: reads in_ptr0 in R0_BLOCK-wide chunks, converts to float32,
#     squares each element and accumulates into _tmp4 under r0_mask; tl.sum
#     over axis 1 then produces the sum of squares (tmp4).
#   * Second loop: re-reads in_ptr0, reads in_ptr1, and stores
#     in_ptr0 * rsqrt(tmp4 / 2048.0 + 1e-05) * in_ptr1 through out_ptr1.
#   * NOTE(review): presumably in_ptr0 is arg5_1 (the bf16[1, 2048] input) and
#     in_ptr1 is arg4_1 (the bf16[2048] operand of mul_1) -- confirm against
#     the call site, which is not part of this section of the file.
#   * NOTE(review): the graph fragment above converts to bf16 before the last
#     multiply (convert_element_type_1), while the kernel text casts tmp13 to
#     float32; narrowing presumably happens only at the store to the *bf16
#     pointer -- confirm if bit-exact parity with eager mode matters.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 16384}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tmp1 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp6 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tl.full([1, 1], 2048.0, tl.float32)
        tmp9 = (tmp4 / tmp8)
        tmp10 = tl.full([1, 1], 1e-05, tl.float32)
        tmp11 = tmp9 + tmp10
        tmp12 = libdevice.rsqrt(tmp11)
        tmp13 = tmp7 * tmp12
        tmp14 = tmp13.to(tl.float32)
        tmp16 = tmp14 * tmp15
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp16, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/uv/cuvo6e5a4aywzrqf6tk2ykcqmb7v5qpw7ab5ihyhfij4je62n7ai.py
# Topologically Sorted Source Nodes: [view_1, key_states_1, k, chunk_1, cos, sin, key_cache, mul_4, neg_1, cat_1, mul_5, k_embed, key_states_2, setitem], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.split, aten.index, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_1 => cat_1
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_9
#   k_embed => add_2
#   key_cache => select, select_1
#   key_states_1 => permute_4
#   key_states_2 => convert_element_type_11
#   mul_4 => mul_4
#   mul_5 => mul_5
#   neg_1 => neg_1
#   setitem => index_put, view_3
#   sin => index_1
#   view_1 => view_1
# Graph fragment:
#   %copy_ : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=copy_]
#   %view_1 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [1, 1, 2, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_1, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_9 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_9, 64, -1), kwargs = {})
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %select : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg3_1, 0, 0), kwargs = {})
#   %select_1 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select, 0, 0), kwargs = {})
#   %mul_4 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_9, %index), kwargs = {})
#   %neg_1 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_1 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_5 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_2 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
#   %convert_element_type_11 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
#   %view_3 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_11, [2, 1, 128]), kwargs = {})
#   %index_put : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_1, [None, None, %arg1_1], %view_3), kwargs = {})
#   return %index_put
# Pointwise copy kernel, registered through async_compile.triton from the
# source string below (kept byte-for-byte because it is passed to Triton as is).
#   * xnumel is overwritten with the constant 2097152 inside the kernel.
#   * Each program copies XBLOCK consecutive elements from in_ptr0 to out_ptr0
#     at the same flat offset; both the load and the store use mask=None.
#   * NOTE(review): despite the long fused name, the kernel body performs only
#     this copy; presumably it materialises the tensor that the following
#     index_put kernel then updates in place -- confirm at the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_1 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zt/cztfimy6up64dqqhnjm2eznav2h3hehwoevglggaentzkaqj7zhw.py
# Topologically Sorted Source Nodes: [view_1, key_states_1, k, chunk_1, cos, sin, key_cache, mul_4, neg_1, cat_1, mul_5, k_embed, key_states_2, setitem], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.split, aten.index, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_1 => cat_1
#   chunk_1 => split_1
#   cos => index
#   k => convert_element_type_9
#   k_embed => add_2
#   key_cache => select, select_1
#   key_states_1 => permute_4
#   key_states_2 => convert_element_type_11
#   mul_4 => mul_4
#   mul_5 => mul_5
#   neg_1 => neg_1
#   setitem => index_put, view_3
#   sin => index_1
#   view_1 => view_1
# Graph fragment:
#   %arg1_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg1_1]
#   %mm_1 : Tensor "bf16[1, 256][256, 1]cuda:0" = PlaceHolder[target=mm_1]
#   %arg0_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg0_1]
#   %arg2_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg2_1]
#   %index_put : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=index_put]
#   %view_1 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [1, 1, 2, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_1, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_9 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_4, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_9, 64, -1), kwargs = {})
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %select : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg3_1, 0, 0), kwargs = {})
#   %select_1 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select, 0, 0), kwargs = {})
#   %mul_4 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_9, %index), kwargs = {})
#   %neg_1 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_1 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %getitem_2], -1), kwargs = {})
#   %mul_5 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %index_1), kwargs = {})
#   %add_2 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
#   %convert_element_type_11 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_2, torch.bfloat16), kwargs = {})
#   %view_3 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_11, [2, 1, 128]), kwargs = {})
#   %index_put : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_1, [None, None, %arg1_1], %view_3), kwargs = {})
#   return %buf5
# Pointwise kernel with an indexed (scatter-style) store, registered through
# async_compile.triton from the source string below (kept byte-for-byte).
#   * xnumel is overwritten with 256; each element is split into
#     x0 = xindex % 128 and x1 = xindex // 128.
#   * in_ptr0[0] is a single int64 index broadcast across the block. A negative
#     value is wrapped by adding 8192 (tmp5, the row used for the store) or
#     32768 (tmp11, the row used to read in_ptr2 / in_ptr3); both wrapped
#     values are range-checked with tl.device_assert.
#   * tmp15 = in_ptr1[x2] * in_ptr2[x0 + 128*tmp11].
#   * tmp33 is the "negate-and-swap halves" view of in_ptr1 within each
#     128-wide row: for x0 < 64 it is -in_ptr1[128*x1 + x0 + 64], otherwise
#     in_ptr1[128*x1 + x0 - 64]. This corresponds to cat([-neg_1 half,
#     other half]) in the graph fragment above.
#   * The stored value is tmp15 + tmp33 * in_ptr3[x0 + 128*tmp11], written to
#     out_ptr0[x0 + 128*tmp5 + 1048576*x1]. inductor_meta lists out_ptr0 under
#     mutated_arg_names, i.e. the kernel updates an existing buffer.
#   * NOTE(review): presumably in_ptr2 / in_ptr3 are the tables named cos / sin
#     in the source-node mapping (arg0_1 / arg2_1) and in_ptr1 is mm_1 --
#     confirm against the call site, which is outside this section.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 256}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 1536}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % 128)
    x1 = xindex // 128
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp7 = tl.load(in_ptr1 + (x2), xmask).to(tl.float32)
    tmp2 = tl.full([XBLOCK], 8192, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 8192), "index out of bounds: 0 <= tmp5 < 8192")
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tl.full([XBLOCK], 32768, tl.int32)
    tmp10 = tmp1 + tmp9
    tmp11 = tl.where(tmp4, tmp10, tmp1)
    tl.device_assert((0 <= tmp11) & (tmp11 < 32768), "index out of bounds: 0 <= tmp11 < 32768")
    tmp13 = tl.load(in_ptr2 + (x0 + 128*tmp11), xmask).to(tl.float32)
    tmp14 = tmp13.to(tl.float32)
    tmp15 = tmp8 * tmp14
    tmp16 = x0
    tmp17 = tl.full([1], 0, tl.int64)
    tmp18 = tmp16 >= tmp17
    tmp19 = tl.full([1], 64, tl.int64)
    tmp20 = tmp16 < tmp19
    tmp21 = tl.load(in_ptr1 + (64 + 128*x1 + (x0)), tmp20 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp22 = tmp21.to(tl.float32)
    tmp23 = -tmp22
    tmp24 = tl.full(tmp23.shape, 0.0, tmp23.dtype)
    tmp25 = tl.where(tmp20, tmp23, tmp24)
    tmp26 = tmp16 >= tmp19
    tmp27 = tl.full([1], 128, tl.int64)
    tmp28 = tmp16 < tmp27
    tmp29 = tl.load(in_ptr1 + (128*x1 + ((-64) + x0)), tmp26 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp30 = tmp29.to(tl.float32)
    tmp31 = tl.full(tmp30.shape, 0.0, tmp30.dtype)
    tmp32 = tl.where(tmp26, tmp30, tmp31)
    tmp33 = tl.where(tmp20, tmp25, tmp32)
    tmp34 = tl.load(in_ptr3 + (x0 + 128*tmp11), xmask).to(tl.float32)
    tmp35 = tmp34.to(tl.float32)
    tmp36 = tmp33 * tmp35
    tmp37 = tmp15 + tmp36
    tmp38 = tmp37.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*tmp5 + 1048576*x1), tmp38, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zx/czx6baaxxu5jwv5kfwa44ehsekrju7btwmlcxrdqxengpngi77dt.py
# Topologically Sorted Source Nodes: [setitem_1, view_2, value_states_1], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_1 => index_put_1, select_7, select_8, view_4
#   value_states_1 => permute_5
#   view_2 => view_2
# Graph fragment:
#   %buf5 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf5]
#   %copy_ : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=copy_]
#   %select_int : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg3_1, 0, 0), kwargs = {})
#   %select_scatter_default : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int, %index_put, 0, 0), kwargs = {})
#   %select_scatter_default_1 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%arg3_1, %select_scatter_default, 0, 0), kwargs = {})
#   %select_7 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_1, 0, 1), kwargs = {})
#   %select_8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_7, 0, 0), kwargs = {})
#   %view_2 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [1, 1, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_2, [0, 2, 1, 3]), kwargs = {})
#   %view_4 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_5, [2, 1, 128]), kwargs = {})
#   %index_put_1 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_8, [None, None, %arg1_1], %view_4), kwargs = {})
#   return %index_put_1
# Pointwise kernel registered through async_compile.triton from the source
# string below (kept byte-for-byte because it is passed to Triton as is).
#   * xnumel is overwritten with the constant 2097152; loads and the store use
#     mask=None.
#   * The two tl.where conditions are built from constants: tmp2 is (1 == 0),
#     i.e. always False, and tmp3 is (0 == 0), i.e. always True. The outer
#     tl.where(tmp2, tmp6, tmp7) therefore always yields tmp7, so the value
#     stored to out_ptr0[x0] is in_ptr1[58720256 + x0]; the loads into tmp4
#     (in_ptr0) and tmp5 (in_ptr1 + x0) do not influence the result.
#   * NOTE(review): 58720256 equals the dim-0 stride of the 6-D copy_ tensor
#     in the graph fragment above, so this presumably copies the slice at
#     index 1 along dim 0 (select_7 / select_8) -- confirm at the call site.
triton_poi_fused_index_put_select_transpose_view_3 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_3', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_3(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp4 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
    tmp7 = tl.load(in_ptr1 + (58720256 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tmp1 == tmp1
    tmp6 = tl.where(tmp3, tmp4, tmp5)
    tmp8 = tl.where(tmp2, tmp6, tmp7)
    tl.store(out_ptr0 + (x0), tmp8, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/e5/ce5jvl5qcyxbktbf65wgngyqry3y6nafwgz6ooalliscjsg5nf23.py
# Topologically Sorted Source Nodes: [setitem_1, view_2, value_states_1], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_1 => index_put_1, select_7, select_8, view_4
#   value_states_1 => permute_5
#   view_2 => view_2
# Graph fragment:
#   %arg1_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg1_1]
#   %mm_2 : Tensor "bf16[1, 256][256, 1]cuda:0" = PlaceHolder[target=mm_2]
#   %index_put_1 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=index_put_1]
#   %select_int : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg3_1, 0, 0), kwargs = {})
#   %select_scatter_default : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int, %index_put, 0, 0), kwargs = {})
#   %select_scatter_default_1 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%arg3_1, %select_scatter_default, 0, 0), kwargs = {})
#   %select_7 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_1, 0, 1), kwargs = {})
#   %select_8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_7, 0, 0), kwargs = {})
#   %view_2 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [1, 1, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_2, [0, 2, 1, 3]), kwargs = {})
#   %view_4 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_5, [2, 1, 128]), kwargs = {})
#   %index_put_1 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_8, [None, None, %arg1_1], %view_4), kwargs = {})
#   return %buf8
# Pointwise kernel with an indexed (scatter-style) store, registered through
# async_compile.triton from the source string below (kept byte-for-byte).
#   * xnumel is overwritten with 256; each element is split into
#     x0 = xindex % 128 and x1 = xindex // 128.
#   * in_ptr0[0] is a single int64 index broadcast across the block; a negative
#     value is wrapped by adding 8192, and the wrapped value (tmp5) is
#     range-checked with tl.device_assert.
#   * in_ptr1[x2] is written unchanged to out_ptr0[x0 + 128*tmp5 + 1048576*x1].
#     inductor_meta lists out_ptr0 under mutated_arg_names, i.e. the kernel
#     updates an existing buffer.
#   * NOTE(review): presumably in_ptr0 is arg1_1 and in_ptr1 is mm_2 from the
#     graph fragment above -- confirm against the call site, which is outside
#     this section of the file.
triton_poi_fused_index_put_select_transpose_view_4 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_4', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 256}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_4', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 512}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_4(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % 128)
    x1 = xindex // 128
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp7 = tl.load(in_ptr1 + (x2), xmask).to(tl.float32)
    tmp2 = tl.full([XBLOCK], 8192, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 8192), "index out of bounds: 0 <= tmp5 < 8192")
    tl.store(out_ptr0 + (x0 + 128*tmp5 + 1048576*x1), tmp7, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/wo/cwooowltfxjc2qndh5h7dtsrkhxp6zd5cjsgqvhpt3kedtqah4oe.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf8]
#   %buf5 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf5]
#   %copy_ : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=copy_]
#   %select_int : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg3_1, 0, 0), kwargs = {})
#   %select_scatter_default : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int, %index_put, 0, 0), kwargs = {})
#   %select_scatter_default_1 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%arg3_1, %select_scatter_default, 0, 0), kwargs = {})
#   %select_int_1 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_1, 0, 1), kwargs = {})
#   %select_scatter_default_2 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_1, %index_put_1, 0, 0), kwargs = {})
#   %select_scatter_default_3 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_1, %select_scatter_default_2, 0, 1), kwargs = {})
#   return %select_scatter_default_3
# Pointwise Triton kernel whose source text is handed to async_compile.triton
# (device_str='cuda'); the result is bound to the module-level name below.
# Covers 117440512 elements (= 2 * 58720256 = 2 * 28 * 2097152) with no bounds
# masking: xmask is all-True and every load/store passes None as its mask.
# Index decomposition: x2 = x // 58720256, x1 = (x // 2097152) % 28,
# x0 = x % 2097152, x3 = x % 58720256.
# Value written to out_ptr0[x]:
#   x2 == 1 and x1 == 0 -> in_ptr0[x0]
#   x2 == 1 and x1 != 0 -> in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 0 -> in_ptr1[x0]
#   x2 == 0 and x1 != 0 -> in_ptr2[x3]
#   any other x2        -> in_ptr2[x]
# tmp7 compares the constants 1 and 0, so tl.where(tmp7, tmp10, tmp11) always
# selects tmp11.
# NOTE(review): per the graph fragment in the preceding comment this appears to
# materialise the two nested select_scatter results (in_ptr0/in_ptr1 presumably
# buf8/buf5, in_ptr2 presumably copy_) -- confirm the argument order at the
# call site.
triton_poi_fused_5 = async_compile.triton('triton_poi_fused_5', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_5(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp11 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp15 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 0, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tmp1 == tmp4
    tmp10 = tl.where(tmp5, tmp8, tmp9)
    tmp12 = tl.where(tmp7, tmp10, tmp11)
    tmp13 = tl.where(tmp5, tmp6, tmp12)
    tmp14 = tmp0 == tmp4
    tmp16 = tl.where(tmp14, tmp10, tmp15)
    tmp17 = tl.where(tmp2, tmp13, tmp16)
    tl.store(out_ptr0 + (x4), tmp17, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/p5/cp5cly5o3vis4qkzwtsu3s74ufumuyunqzocfj7aodsk6556vvfw.py
# Topologically Sorted Source Nodes: [view, query_states_1, q, chunk, cos, mul_2, neg, cat, sin, mul_3, q_embed, attn_output], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.split, aten.index, aten.mul, aten.neg, aten.cat, aten.add]
# Source node to ATen node mapping:
#   attn_output => mul_6
#   cat => cat
#   chunk => split
#   cos => index
#   mul_2 => mul_2
#   mul_3 => mul_3
#   neg => neg
#   q => convert_element_type_8
#   q_embed => add_1
#   query_states_1 => permute_3
#   sin => index_1
#   view => view
# Graph fragment:
#   %mm : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm]
#   %arg1_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg1_1]
#   %arg0_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg0_1]
#   %arg2_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg2_1]
#   %view : Tensor "bf16[1, 1, 16, 128][2048, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [1, 1, 16, 128]), kwargs = {})
#   %permute_3 : Tensor "bf16[1, 16, 1, 128][2048, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_8 : Tensor "f32[1, 16, 1, 128][2048, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_3, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_8, 64, -1), kwargs = {})
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %mul_2 : Tensor "f32[1, 16, 1, 128][2048, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_8, %index), kwargs = {})
#   %neg : Tensor "f32[1, 16, 1, 64][1024, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat : Tensor "f32[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %mul_3 : Tensor "f32[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat, %index_1), kwargs = {})
#   %add_1 : Tensor "f32[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_2, %mul_3), kwargs = {})
#   %mul_6 : Tensor "f32[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%add_1, 0.29730177875068026), kwargs = {})
#   return %expand_2
# Pointwise Triton kernel whose source text is handed to async_compile.triton
# (device_str='cuda'); the result is bound to the module-level name below.
# Over 2048 elements (x0 = x % 128, x1 = x // 128) the body computes in float32:
#   out_ptr0[x] = (a * c + r * s) * 0.29730177875068026
# where
#   pos = in_ptr1[0], wrapped by +32768 when negative and device-asserted to
#         lie in [0, 32768)
#   a   = in_ptr0[x]
#   c   = in_ptr2[x0 + 128*pos]
#   s   = in_ptr3[x0 + 128*pos]
#   r   = -in_ptr0[128*x1 + x0 + 64]   when x0 <  64
#          in_ptr0[128*x1 + x0 - 64]   when x0 >= 64
# The two half-selecting loads are masked (other=0.0) and tl.where(tmp16, ...)
# picks the branch that applies. tmp14 (x0 >= 0) and tmp24 (x0 < 128) are
# computed but never read.
# NOTE(review): per the graph fragment in the preceding comment, in_ptr2/in_ptr3
# are presumably the cos/sin tables gathered at the position index and this is a
# rotate-half rotary embedding of the query; 0.29730177875068026 is numerically
# 128 ** -0.25 -- confirm both against the call site / model code.
triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2048}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 28672}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % 128)
    x1 = xindex // 128
    tmp0 = tl.load(in_ptr0 + (x2), xmask).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (0))
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK])
    tmp1 = tmp0.to(tl.float32)
    tmp4 = tl.full([XBLOCK], 32768, tl.int32)
    tmp5 = tmp3 + tmp4
    tmp6 = tmp3 < 0
    tmp7 = tl.where(tmp6, tmp5, tmp3)
    tl.device_assert((0 <= tmp7) & (tmp7 < 32768), "index out of bounds: 0 <= tmp7 < 32768")
    tmp9 = tl.load(in_ptr2 + (x0 + 128*tmp7), xmask).to(tl.float32)
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tmp1 * tmp10
    tmp12 = x0
    tmp13 = tl.full([1], 0, tl.int64)
    tmp14 = tmp12 >= tmp13
    tmp15 = tl.full([1], 64, tl.int64)
    tmp16 = tmp12 < tmp15
    tmp17 = tl.load(in_ptr0 + (64 + 128*x1 + (x0)), tmp16 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp18 = tmp17.to(tl.float32)
    tmp19 = -tmp18
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp16, tmp19, tmp20)
    tmp22 = tmp12 >= tmp15
    tmp23 = tl.full([1], 128, tl.int64)
    tmp24 = tmp12 < tmp23
    tmp25 = tl.load(in_ptr0 + (128*x1 + ((-64) + x0)), tmp22 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp26 = tmp25.to(tl.float32)
    tmp27 = tl.full(tmp26.shape, 0.0, tmp26.dtype)
    tmp28 = tl.where(tmp22, tmp26, tmp27)
    tmp29 = tl.where(tmp16, tmp21, tmp28)
    tmp30 = tl.load(in_ptr3 + (x0 + 128*tmp7), xmask).to(tl.float32)
    tmp31 = tmp30.to(tl.float32)
    tmp32 = tmp29 * tmp31
    tmp33 = tmp11 + tmp32
    tmp34 = tl.full([1], 0.29730177875068026, tl.float32)
    tmp35 = tmp33 * tmp34
    tl.store(out_ptr0 + (x2), tmp35, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ug/cugas2keyp2kryezst47i4w2acrfm7ojvhfwg3uaxjm4lysbx46x.py
# Topologically Sorted Source Nodes: [attn_output], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output => clone, convert_element_type_13, expand, mul_7, permute_6, select_12, select_13, unsqueeze, view_5
# Graph fragment:
#   %select_scatter_default_3 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_12 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 0), kwargs = {})
#   %select_13 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_12, 0, 0), kwargs = {})
#   %convert_element_type_13 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_13, torch.float32), kwargs = {})
#   %unsqueeze : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_13, 2), kwargs = {})
#   %expand : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand,), kwargs = {memory_format: torch.contiguous_format})
#   %view_5 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone, [1, 16, 8192, 128]), kwargs = {})
#   %permute_6 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_5, [0, 1, 3, 2]), kwargs = {})
#   %mul_7 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_6, 0.29730177875068026), kwargs = {})
#   return %expand_3
# 2-D pointwise Triton kernel whose source text is handed to
# async_compile.triton (device_str='cuda'); the result is bound to the
# module-level name below.
# The tile is indexed by y (program_id(1), ynumel = 2048) and x (program_id(0),
# xnumel = 8192); both masks are all-True and the load/store pass None.
# With y0 = y % 128 and y1 = y // 128 the body computes
#   out_ptr0[x + 8192*y] =
#       float32(in_ptr0[y0 + 128*x + 1048576*(y1 // 8)]) * 0.29730177875068026
# The (y1 // 8) term makes every 8 consecutive y1 values read the same
# 1048576-element input slab. x steps through the input with stride 128 but
# through the output with stride 1, so the 128-wide and 8192-long axes swap.
# NOTE(review): per the graph fragment in the preceding comment this looks like
# expanding 2 slabs to 16, transposing the last two axes and scaling
# (0.2973... is numerically 128 ** -0.25) -- confirm at the call site.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_7 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_7', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_7(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/nv/cnv4xi2q5oam5kme44wqsvfehqqw2upbp33nxuat3fni7ogplyvf.py
# Topologically Sorted Source Nodes: [attn_output, arange, attn_mask], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
# Source node to ATen node mapping:
#   arange => iota
#   attn_mask => le
#   attn_output => add_3, any_1, div, eq, full_default, full_default_1, full_default_2, logical_not, logical_not_1, view_9, where, where_1
# Graph fragment:
#   %bmm : Tensor "f32[16, 1, 8192][8192, 8192, 1]cuda:0" = PlaceHolder[target=bmm]
#   %arg1_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg1_1]
#   %any_1 : Tensor "b8[1, 16, 1, 1][16, 1, 16, 16]cuda:0" = PlaceHolder[target=any_1]
#   %getitem_166 : Tensor "f32[1, 16, 1, 1][16, 1, 16, 16]cuda:0" = PlaceHolder[target=getitem_166]
#   %getitem_167 : Tensor "f32[1, 16, 1, 1][16, 1, 16, 16]cuda:0" = PlaceHolder[target=getitem_167]
#   %view_9 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [1, 16, 1, 8192]), kwargs = {})
#   %iota : Tensor "i64[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8192,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %le : Tensor "b8[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.le.Tensor](args = (%iota, %arg1_1), kwargs = {})
#   %full_default_1 : Tensor "bf16[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %full_default : Tensor "bf16[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], -inf), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %where : Tensor "bf16[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le, %full_default_1, %full_default), kwargs = {})
#   %add_3 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_9, %where), kwargs = {})
#   %eq : Tensor "b8[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%add_3, -inf), kwargs = {})
#   %logical_not : Tensor "b8[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.logical_not.default](args = (%eq,), kwargs = {})
#   %any_1 : Tensor "b8[1, 16, 1, 1][16, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.any.dim](args = (%logical_not, -1, True), kwargs = {})
#   %logical_not_1 : Tensor "b8[1, 16, 1, 1][16, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.logical_not.default](args = (%any_1,), kwargs = {})
#   %full_default_2 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([1, 16, 1, 8192], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %prepare_softmax_online_default_27 : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%add_3, -1), kwargs = {})
#   %sub_tensor_27 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_3, %getitem_166), kwargs = {})
#   %exp_default_27 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor_27,), kwargs = {})
#   %div : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default_27, %getitem_167), kwargs = {})
#   %where_1 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_not_1, %full_default_2, %div), kwargs = {})
#   return %any_1,%getitem_166,%getitem_167,%expand_4
# Looped-reduction Triton kernel whose source text is handed to
# async_compile.triton (device_str='cuda'); the result is bound to the
# module-level name below.
# It works on 16 rows (xnumel) of 8192 elements (r0_numel) stored at
# in_out_ptr0[r + 8192*row]; in_out_ptr0 is listed in 'mutated_arg_names' and is
# overwritten in place. in_ptr0[0] supplies a single int64 threshold.
# Pass 1 (first r0 loop), per R0_BLOCK chunk:
#   * v = in_out_ptr0[...] + (0.0 if r <= threshold else -inf)
#   * _tmp15 ORs in "v != -inf" to track whether the row has any such element
#   * _tmp19_max / _tmp19_sum are updated through
#     triton_helpers.online_softmax_combine
# After the loop, triton_helpers.any collapses _tmp15 per row and
# triton_helpers.online_softmax_reduce yields tmp19 / tmp20 per row.
# Pass 2 (second r0 loop) reloads the row, rebuilds v the same way and stores
#   0.0                              if the row had no element != -inf
#   libdevice.exp(v - tmp19) / tmp20 otherwise
# NOTE(review): tmp19 / tmp20 are presumably the row maximum and the sum of
# exp(v - max) (the helpers' definitions are not in this file), and the
# threshold is presumably the current sequence position used as a causal
# mask -- confirm both.
triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8 = async_compile.triton('triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 16, 'r0_': 8192},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 3, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 1572864}}
)
@triton.jit
def triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8(in_out_ptr0, in_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 16
    r0_numel = 8192
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    tmp1 = tl.load(in_ptr0 + (0))
    tmp2 = tl.broadcast_to(tmp1, [1, 1])
    _tmp15 = tl.full([XBLOCK, R0_BLOCK], False, tl.int1)
    _tmp19_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
    _tmp19_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_out_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
        tmp3 = r0_1
        tmp4 = tmp3 <= tmp2
        tmp5 = tl.full([1, 1], 0.0, tl.float32)
        tmp6 = tl.full([1, 1], float("-inf"), tl.float32)
        tmp7 = tl.where(tmp4, tmp5, tmp6)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp0 + tmp8
        tmp10 = tmp9 == tmp6
        tmp11 = tmp10 == 0
        tmp12 = tmp11.to(tl.int64)
        tmp13 = (tmp12 != 0)
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
        tmp16 = _tmp15 | tmp14
        _tmp15 = tl.where(r0_mask & xmask, tmp16, _tmp15)
        tmp18 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])

        _tmp19_max_next, _tmp19_sum_next = triton_helpers.online_softmax_combine(
            _tmp19_max, _tmp19_sum, tmp18, False
        )

        _tmp19_max = tl.where(r0_mask & xmask, _tmp19_max_next, _tmp19_max)
        _tmp19_sum = tl.where(r0_mask & xmask, _tmp19_sum_next, _tmp19_sum)
    tmp17 = _tmp15.to(tl.int8)
    tmp15 = triton_helpers.any(tmp17, 1)[:, None]

    tmp19, tmp20 = triton_helpers.online_softmax_reduce(
        _tmp19_max, _tmp19_sum, 1, False)
    tmp19 = tmp19[:, None]
    tmp20 = tmp20[:, None]
    tmp23 = tl.load(in_ptr0 + (0))
    tmp24 = tl.broadcast_to(tmp23, [1, 1])
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp22 = tl.load(in_out_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp21 = tmp15 == 0
        tmp25 = r0_1
        tmp26 = tmp25 <= tmp24
        tmp27 = tl.full([1, 1], 0.0, tl.float32)
        tmp28 = tl.full([1, 1], float("-inf"), tl.float32)
        tmp29 = tl.where(tmp26, tmp27, tmp28)
        tmp30 = tmp29.to(tl.float32)
        tmp31 = tmp22 + tmp30
        tmp32 = tmp31 - tmp19
        tmp33 = libdevice.exp(tmp32)
        tmp34 = (tmp33 / tmp20)
        tmp35 = tl.where(tmp21, tmp27, tmp34)
        tl.store(in_out_ptr0 + (r0_1 + 8192*x0), tmp35, r0_mask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/m2/cm27uuettv4alfoi6qgjydelkog665mytfcx7cpx5ooi6nwtpfpy.py
# Topologically Sorted Source Nodes: [setitem_1, attn_output], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output => clone_1, convert_element_type_14, expand_1, unsqueeze_1
#   setitem_1 => select_10, select_11
# Graph fragment:
#   %select_scatter_default_3 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_10 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 1), kwargs = {})
#   %select_11 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_10, 0, 0), kwargs = {})
#   %convert_element_type_14 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_11, torch.float32), kwargs = {})
#   %unsqueeze_1 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_14, 2), kwargs = {})
#   %expand_1 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_1, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_1 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_1,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_1
# Pointwise Triton kernel whose source text is handed to async_compile.triton
# (device_str='cuda'); the result is bound to the module-level name below.
# Covers 16777216 elements with no bounds masking (xmask is all-True and the
# load/store pass None). With x0 = x % 1048576 and x2 = x // 8388608:
#   out_ptr0[x] = float32(in_ptr0[58720256 + x0 + 1048576*x2])
# Each run of 8388608 outputs (8 chunks of 1048576) therefore repeats one
# 1048576-element input slab 8 times; the fixed 58720256 offset skips the first
# 58720256 input elements.
# NOTE(review): per the graph fragment in the preceding comment, 58720256 is the
# dim-0 stride of the source tensor, so this presumably reads index 1 along
# dim 0 and expands 2 slabs to 16 -- confirm at the call site.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_9 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_9', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_9(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (58720256 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/jp/cjpurwo7lpuzatuozafcu255bsabnidb4gdszw7j5ojydzh7e4ej.py
# Topologically Sorted Source Nodes: [attn_output], Original ATen: [aten.view, aten._to_copy]
# Source node to ATen node mapping:
#   attn_output => convert_element_type_15, view_12
# Graph fragment:
#   %bmm_1 : Tensor "f32[16, 1, 128][128, 128, 1]cuda:0" = PlaceHolder[target=bmm_1]
#   %view_12 : Tensor "f32[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [1, 16, 1, 128]), kwargs = {})
#   %convert_element_type_15 : Tensor "bf16[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_12, torch.bfloat16), kwargs = {})
#   return %convert_element_type_15
# Pointwise Triton kernel: element-wise copy of 2048 values from an fp32
# buffer (in_ptr0) to a bf16 buffer (out_ptr0), same flat index on both sides.
#
# The reshape in the graph fragment above does not show up in the indexing:
# the kernel addresses both buffers with the plain flat index x0.
# Loads and stores are masked with `xindex < xnumel` (xnumel is overwritten
# with the literal 2048 inside the kernel).
# NOTE(review): there is no explicit cast to bf16 in the body (tmp1 is still
# fp32); the narrowing is left to tl.store writing through the '*bf16'
# pointer declared in triton_meta -- see the Triton tl.store documentation.
# The kernel source below is a runtime string handed to the Triton compiler
# and must not be edited by hand.
triton_poi_fused__to_copy_view_10 = async_compile.triton('triton_poi_fused__to_copy_view_10', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2048}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_view_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16384}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_view_10(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qb/cqbdysuimpqahd7onuyx776erjbu3jbg2mxbkntalbmdxf3btxoo.py
# Topologically Sorted Source Nodes: [hidden_states_1, to_6, pow_2, variance_1, add_4, rsqrt_1, mul_6, hidden_1, hidden_states_2], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_4 => add_5
#   hidden_1 => convert_element_type_20
#   hidden_states_1 => add_4
#   hidden_states_2 => mul_9
#   mul_6 => mul_8
#   pow_2 => pow_2
#   rsqrt_1 => rsqrt_1
#   to_6 => convert_element_type_19
#   variance_1 => mean_1
# Graph fragment:
#   %arg5_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %mm_3 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %buf21 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf21]
#   %arg10_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg10_1]
#   %add_4 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg5_1, %mm_3), kwargs = {})
#   %convert_element_type_19 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_4, torch.float32), kwargs = {})
#   %pow_2 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_19, 2), kwargs = {})
#   %mean_1 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
#   %add_5 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
#   %rsqrt_1 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_5,), kwargs = {})
#   %mul_8 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_4, %rsqrt_1), kwargs = {})
#   %convert_element_type_20 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_8, torch.bfloat16), kwargs = {})
#   %mul_9 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_20, %arg10_1), kwargs = {})
#   return %buf21,%mul_9
# Reduction Triton kernel: fused residual add + RMS-style normalisation over a
# single row of 2048 elements (xnumel == 1, r0_numel == 2048, both overwritten
# with literals inside the kernel).
#
# Arguments (all '*bf16' per triton_meta):
#   in_ptr0, in_ptr1 -- the two addends (arg5_1 and mm_3 in the graph fragment)
#   in_ptr2          -- per-element scale (arg10_1 in the graph fragment)
#   out_ptr1         -- output row
#
# Pass 1 (first loop): accumulates (in_ptr0[i] + in_ptr1[i])**2 into the
#   [XBLOCK, R0_BLOCK] accumulator _tmp6, then reduces it with tl.sum along
#   axis 1.  The tl.where on r0_mask keeps masked-off lanes out of the sum.
# Pass 2 (second loop): reloads the addends, recomputes the sum, and stores
#   sum * rsqrt(tmp6 / 2048.0 + 1e-05) * in_ptr2[i].
#
# Every load is upcast with .to(tl.float32) and all arithmetic between load
# and store is carried out on those fp32 values; no intermediate cast back to
# bf16 appears in the body even though the graph fragment lists bf16
# intermediates.  Pass 1 loads use eviction_policy='evict_last' and pass 2
# loads use 'evict_first'.
# The kernel source below is a runtime string handed to the Triton compiler
# and must not be edited by hand.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 20480}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp6 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3 * tmp3
        tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
        tmp7 = _tmp6 + tmp5
        _tmp6 = tl.where(r0_mask, tmp7, _tmp6)
    tmp6 = tl.sum(_tmp6, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp8 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp19 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tmp8 + tmp9
        tmp11 = tmp10.to(tl.float32)
        tmp12 = tl.full([1, 1], 2048.0, tl.float32)
        tmp13 = (tmp6 / tmp12)
        tmp14 = tl.full([1, 1], 1e-05, tl.float32)
        tmp15 = tmp13 + tmp14
        tmp16 = libdevice.rsqrt(tmp15)
        tmp17 = tmp11 * tmp16
        tmp18 = tmp17.to(tl.float32)
        tmp20 = tmp18 * tmp19
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp20, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/5q/c5qhod42xzt5enox4t3db6ulbfw5zxvrwgdkwzimhjzggw7u6ffj.py
# Topologically Sorted Source Nodes: [silu, mul_8], Original ATen: [aten.silu, aten.mul]
# Source node to ATen node mapping:
#   mul_8 => mul_10
#   silu => add_6, convert_element_type_23, convert_element_type_24, div_1, exp_1, neg_2
# Graph fragment:
#   %mm_4 : Tensor "bf16[1, 6144][6144, 1]cuda:0" = PlaceHolder[target=mm_4]
#   %mm_5 : Tensor "bf16[1, 6144][6144, 1]cuda:0" = PlaceHolder[target=mm_5]
#   %convert_element_type_23 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_4, torch.float32), kwargs = {})
#   %neg_2 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%convert_element_type_23,), kwargs = {})
#   %exp_1 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%neg_2,), kwargs = {})
#   %add_6 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%exp_1, 1), kwargs = {})
#   %div_1 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_23, %add_6), kwargs = {})
#   %convert_element_type_24 : Tensor "bf16[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%div_1, torch.bfloat16), kwargs = {})
#   %mul_10 : Tensor "bf16[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_24, %mm_5), kwargs = {})
#   return %mul_10
# Pointwise Triton kernel: in-place SiLU(x) * y over 6144 elements.
#
# Arguments (both '*bf16' per triton_meta):
#   in_out_ptr0 -- read as x and overwritten with the result (it is listed in
#                  inductor_meta['mutated_arg_names']); mm_4 in the graph
#                  fragment above
#   in_ptr0     -- the multiplier y; mm_5 in the graph fragment above
#
# Per element the body computes x / (exp(-x) + 1.0) and multiplies it by y.
# Both loads are upcast with .to(tl.float32) and the arithmetic runs on those
# fp32 values.  Loads and the store are masked with `xindex < xnumel`
# (xnumel is overwritten with the literal 6144 inside the kernel).
# The kernel source below is a runtime string handed to the Triton compiler
# and must not be edited by hand.
triton_poi_fused_mul_silu_12 = async_compile.triton('triton_poi_fused_mul_silu_12', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8192}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_12', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 49152}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_12(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6144
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp8 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = -tmp1
    tmp3 = libdevice.exp(tmp2)
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = (tmp1 / tmp5)
    tmp7 = tmp6.to(tl.float32)
    tmp9 = tmp7 * tmp8
    tl.store(in_out_ptr0 + (x0), tmp9, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/sl/csl7gh6onpu2bf4yqtlc5qh5bdj6icwhb6w3tyeixarkfga2h63v.py
# Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_6 => add_8
#   hidden_2 => convert_element_type_30
#   hidden_states_1 => add_4
#   hidden_states_4 => add_7
#   hidden_states_5 => mul_12
#   mul_9 => mul_11
#   pow_3 => pow_3
#   rsqrt_2 => rsqrt_2
#   to_8 => convert_element_type_29
#   variance_2 => mean_2
# Graph fragment:
#   %arg5_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %mm_3 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %buf27 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf27]
#   %arg14_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg14_1]
#   %add_4 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg5_1, %mm_3), kwargs = {})
#   %add_7 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_4, %mm_6), kwargs = {})
#   %convert_element_type_29 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.float32), kwargs = {})
#   %pow_3 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_29, 2), kwargs = {})
#   %mean_2 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
#   %add_8 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {})
#   %rsqrt_2 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_8,), kwargs = {})
#   %mul_11 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_7, %rsqrt_2), kwargs = {})
#   %convert_element_type_30 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_11, torch.bfloat16), kwargs = {})
#   %mul_12 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_30, %arg14_1), kwargs = {})
#   return %buf27,%mul_12
# Reduction Triton kernel: fused double residual add + RMS-style normalisation
# over a single row of 2048 elements (xnumel == 1, r0_numel == 2048, both
# overwritten with literals inside the kernel).
#
# Arguments (all '*bf16' per triton_meta):
#   in_ptr0, in_ptr1, in_ptr2 -- the three addends (arg5_1, mm_3 and mm_6 in
#                                the graph fragment above)
#   in_ptr3                   -- per-element scale (arg14_1 in the graph
#                                fragment above)
#   out_ptr1                  -- output row
#
# Pass 1 (first loop): accumulates (in_ptr0[i] + in_ptr1[i] + in_ptr2[i])**2
#   into the [XBLOCK, R0_BLOCK] accumulator _tmp8, then reduces it with tl.sum
#   along axis 1.  The tl.where on r0_mask keeps masked-off lanes out of the
#   sum.
# Pass 2 (second loop): reloads the addends, recomputes the three-way sum, and
#   stores sum * rsqrt(tmp8 / 2048.0 + 1e-05) * in_ptr3[i].
#
# Every load is upcast with .to(tl.float32) and all arithmetic between load
# and store is carried out on those fp32 values; no intermediate cast back to
# bf16 appears in the body even though the graph fragment lists bf16
# intermediates.  Pass 1 loads use eviction_policy='evict_last' and pass 2
# loads use 'evict_first'.
# The kernel source below is a runtime string handed to the Triton compiler
# and must not be edited by hand.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 24576}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp5 * tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
        tmp9 = _tmp8 + tmp7
        _tmp8 = tl.where(r0_mask, tmp9, _tmp8)
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp10 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tmp10 + tmp11
        tmp14 = tmp12 + tmp13
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tl.full([1, 1], 2048.0, tl.float32)
        tmp17 = (tmp8 / tmp16)
        tmp18 = tl.full([1, 1], 1e-05, tl.float32)
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp24, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qb/cqbab4swzwbfroi44p4grlpsm4sm6ccioqfzgdouip5x5ntq3fbx.py
# Topologically Sorted Source Nodes: [cos, sin, view_4, key_states_4, k_1, chunk_3, setitem_2, mul_13, neg_3, cat_3, mul_14, k_embed_1, key_states_5], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_3 => cat_3
#   chunk_3 => split_3
#   cos => index
#   k_1 => convert_element_type_38
#   k_embed_1 => add_10
#   key_states_4 => permute_16
#   key_states_5 => convert_element_type_40
#   mul_13 => mul_15
#   mul_14 => mul_16
#   neg_3 => neg_4
#   setitem_2 => index_put_2, select_20, select_21, view_17
#   sin => index_1
#   view_4 => view_15
# Graph fragment:
#   %select_scatter_default_3 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_15 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_8, [1, 1, 2, 128]), kwargs = {})
#   %permute_16 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_15, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_38 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_16, torch.float32), kwargs = {})
#   %split_3 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_38, 64, -1), kwargs = {})
#   %select_20 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 0), kwargs = {})
#   %select_21 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_20, 0, 1), kwargs = {})
#   %mul_15 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_38, %index), kwargs = {})
#   %neg_4 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_7,), kwargs = {})
#   %cat_3 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_4, %getitem_6], -1), kwargs = {})
#   %mul_16 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_3, %index_1), kwargs = {})
#   %add_10 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_15, %mul_16), kwargs = {})
#   %convert_element_type_40 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_10, torch.bfloat16), kwargs = {})
#   %view_17 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_40, [2, 1, 128]), kwargs = {})
#   %index_put_2 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_21, [None, None, %arg1_1], %view_17), kwargs = {})
#   return %index_put_2
# Pointwise Triton kernel: contiguous copy of 2097152 bf16 elements from
# in_ptr0, starting at element offset 2097152, into out_ptr0 starting at 0.
#
# Despite the long fused-op name, the kernel body contains only this single
# load/store pair; none of the mul/neg/cat/add arithmetic listed in the graph
# fragment above is present here.
# NOTE(review): the offset 2097152 equals the dim-1 stride of
# select_scatter_default_3 in the graph fragment, which is consistent with
# select_21 (index 0 on dim 0, then index 1 on the next dim).  Presumably this
# materialises the base tensor that the index_put later writes into -- confirm
# against the call site, which is outside this view.
#
# xnumel is overwritten with a literal and xmask is all-True; both the load
# and the store pass None as the mask, i.e. no bounds masking is applied.
# The loaded value is upcast with .to(tl.float32) and then stored through the
# '*bf16' out_ptr0.
# The kernel source below is a runtime string handed to the Triton compiler
# and must not be edited by hand.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_14 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_14', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_14', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_14(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2097152 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ut/cut6a2meiv5ogcslrionc4qe7bcb2fhqlxewjnivhylyhqqeblky.py
# Topologically Sorted Source Nodes: [setitem_3, view_5, value_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_3 => index_put_3, select_25, select_26, view_18
#   value_states_3 => permute_17
#   view_5 => view_16
# Graph fragment:
#   %buf32 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf32]
#   %select_scatter_default_3 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_int_2 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 0), kwargs = {})
#   %select_scatter_default_4 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_2, %index_put_2, 0, 1), kwargs = {})
#   %select_scatter_default_5 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_3, %select_scatter_default_4, 0, 0), kwargs = {})
#   %select_25 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_5, 0, 1), kwargs = {})
#   %select_26 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_25, 0, 1), kwargs = {})
#   %view_16 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_9, [1, 1, 2, 128]), kwargs = {})
#   %permute_17 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_16, [0, 2, 1, 3]), kwargs = {})
#   %view_18 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_17, [2, 1, 128]), kwargs = {})
#   %index_put_3 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_26, [None, None, %arg1_1], %view_18), kwargs = {})
#   return %index_put_3
# Pointwise Triton kernel: selects one 2097152-element slice out of a
# select_scatter'd tensor, with the scatter predicates folded to constants.
#
# Arguments (all '*bf16' per triton_meta):
#   in_ptr0  -- buf32 in the graph fragment above, read at x0
#   in_ptr1  -- select_scatter_default_3 in the graph fragment above, read at
#               2097152 + x0 and at 60817408 + x0
#   out_ptr0 -- output, written at x0
#
# The body compares integer constants instead of runtime indices:
#   tmp2 = (1 == 0) -> False,  tmp3 = (1 == 1) -> True
#   tmp6 = where(tmp3, tmp4, tmp5) -> tmp4
#   tmp8 = where(tmp2, tmp6, tmp7) -> tmp7
# so with the constants as written the stored value is always tmp7, i.e.
# in_ptr1[60817408 + x0] (60817408 == 58720256 + 2097152).  The tmp4 and tmp5
# loads never reach the output.
# NOTE(review): whether the Triton compiler drops those two unused loads is
# not visible here; inductor_meta still records num_load == 3.
#
# xnumel is overwritten with a literal and xmask is all-True; the loads and
# the store pass None as the mask, i.e. no bounds masking is applied.
# The kernel source below is a runtime string handed to the Triton compiler
# and must not be edited by hand.
triton_poi_fused_index_put_select_transpose_view_15 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_15', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_15(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp4 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr1 + (2097152 + x0), None).to(tl.float32)
    tmp7 = tl.load(in_ptr1 + (60817408 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tmp0 == tmp0
    tmp6 = tl.where(tmp3, tmp4, tmp5)
    tmp8 = tl.where(tmp2, tmp6, tmp7)
    tl.store(out_ptr0 + (x0), tmp8, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/es/cesjsjqj4bh2ncy5ylt4bfomeipp2ihlvqrbi5i4gtritl4locio.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf35 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf35]
#   %buf32 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf32]
#   %select_scatter_default_3 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_int_2 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 0), kwargs = {})
#   %select_scatter_default_4 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_2, %index_put_2, 0, 1), kwargs = {})
#   %select_scatter_default_5 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_3, %select_scatter_default_4, 0, 0), kwargs = {})
#   %select_int_3 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_5, 0, 1), kwargs = {})
#   %select_scatter_default_6 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_3, %index_put_3, 0, 1), kwargs = {})
#   %select_scatter_default_7 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_5, %select_scatter_default_6, 0, 1), kwargs = {})
#   return %select_scatter_default_7
# Pointwise Triton kernel (1D grid) over xnumel = 117440512 (= 2 * 58720256)
# elements; loads and the store use mask None.
#
# Index decomposition of the flat index xindex:
#   x2 = xindex // 58720256          (0 or 1 for xindex < xnumel)
#   x1 = (xindex // 2097152) % 28
#   x0 = xindex % 2097152
#   x3 = xindex % 58720256
#   x4 = xindex
#
# Effective selection performed by the tl.where chain (tmp7 is the constant
# comparison 1 == 0, so tmp12 is always tmp11):
#   x2 == 1:  out[x4] = in_ptr0[x0] if x1 == 1 else in_ptr2[58720256 + x3]
#   x2 == 0:  out[x4] = in_ptr1[x0] if x1 == 1 else in_ptr2[x3]
#   otherwise the value in_ptr2[x4] would be used; with x2 limited to 0 or 1
#   that branch is not taken.
# In other words, the x1 == 1 slab of each x2 half is replaced by in_ptr1
# (x2 == 0) or in_ptr0 (x2 == 1), and everything else is copied from in_ptr2.
triton_poi_fused_16 = async_compile.triton('triton_poi_fused_16', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_16', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_16(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp11 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp15 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tmp3 == tmp1
    tmp6 = tl.full([1], 0, tl.int32)
    tmp7 = tmp1 == tmp6
    tmp10 = tl.where(tmp4, tmp8, tmp9)
    tmp12 = tl.where(tmp7, tmp10, tmp11)
    tmp13 = tl.where(tmp4, tmp5, tmp12)
    tmp14 = tmp0 == tmp6
    tmp16 = tl.where(tmp14, tmp10, tmp15)
    tmp17 = tl.where(tmp2, tmp13, tmp16)
    tl.store(out_ptr0 + (x4), tmp17, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/vt/cvt5wwa4nppmb35unalmzoq3fj5aqqlvskje47fay7gxpr3bfubb.py
# Topologically Sorted Source Nodes: [attn_output_4], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_4 => clone_2, convert_element_type_42, expand_6, mul_18, permute_18, select_30, select_31, unsqueeze_2, view_19
# Graph fragment:
#   %select_scatter_default_7 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_30 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 0), kwargs = {})
#   %select_31 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_30, 0, 1), kwargs = {})
#   %convert_element_type_42 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_31, torch.float32), kwargs = {})
#   %unsqueeze_2 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_42, 2), kwargs = {})
#   %expand_6 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_2, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_2 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_6,), kwargs = {memory_format: torch.contiguous_format})
#   %view_19 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_2, [1, 16, 8192, 128]), kwargs = {})
#   %permute_18 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_19, [0, 1, 3, 2]), kwargs = {})
#   %mul_18 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_18, 0.29730177875068026), kwargs = {})
#   return %expand_9
# 2D pointwise Triton kernel (Grid2D): y spans ynumel = 2048 indices (program
# axis 1), x spans xnumel = 8192 indices (program axis 0); loads and the
# store use mask None.
#
# Indexing:
#   y0 = yindex % 128, y1 = yindex // 128, x2 = xindex
#   read  : in_ptr0[2097152 + y0 + 128*x2 + 1048576*(y1 // 8)]
#           -> 8 consecutive y1 values read the same 1048576-element slab,
#              and x has stride 128 in the input.
#   write : out_ptr0[x2 + 8192*yindex]
#           -> x is the contiguous dimension of the output, so the kernel
#              transposes the (x, y0) layout while copying.
# The loaded bf16 value is converted to fp32, multiplied by the constant
# 0.29730177875068026 and stored as fp32 (out_ptr0 is '*fp32' in the
# signature). The second .to(tl.float32) on tmp0 is a no-op re-cast.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_17 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_17', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_17', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_17(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (2097152 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qd/cqdf2rsrszqji3lj7n2pqnxno4ltmised3qliiqfnurwrxcxbcpx.py
# Topologically Sorted Source Nodes: [setitem_3, attn_output_4], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_4 => clone_3, convert_element_type_43, expand_7, unsqueeze_3
#   setitem_3 => select_28, select_29
# Graph fragment:
#   %select_scatter_default_7 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_28 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 1), kwargs = {})
#   %select_29 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_28, 0, 1), kwargs = {})
#   %convert_element_type_43 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_29, torch.float32), kwargs = {})
#   %unsqueeze_3 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_43, 2), kwargs = {})
#   %expand_7 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_3, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_3 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_7,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_3
# Pointwise Triton kernel (1D grid) over xnumel = 16777216 elements; the load
# and the store use mask None.
#
# Indexing:
#   x0 = xindex % 1048576, x2 = xindex // 8388608
#   read  : in_ptr0[60817408 + x0 + 1048576*x2]
#   write : out_ptr0[xindex]
# The middle factor of the flat index, (xindex // 1048576) % 8, does not
# appear in the load address, so each 1048576-element input slab is written
# 8 times in a row to the output (a materialized broadcast).
# Values are loaded as bf16, converted to fp32 and stored as fp32; the second
# .to(tl.float32) on tmp0 is a no-op re-cast.
# NOTE(review): 60817408 == 58720256 + 2097152, which looks like element
# [1][1] of a buffer with leading strides (58720256, 2097152) -- confirm
# against the call site.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_18 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_18', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_18', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_18(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (60817408 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ul/culkq7jq6ky456hvko5kyfjpvreyrgfo46l2ljxrln232jywc2n5.py
# Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_10 => add_13
#   hidden_3 => convert_element_type_49
#   hidden_states_1 => add_4
#   hidden_states_4 => add_7
#   hidden_states_6 => add_12
#   hidden_states_7 => mul_20
#   mul_15 => mul_19
#   pow_4 => pow_4
#   rsqrt_3 => rsqrt_3
#   to_14 => convert_element_type_48
#   variance_3 => mean_3
# Graph fragment:
#   %arg5_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %mm_3 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %mm_10 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_10]
#   %buf48 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf48]
#   %arg19_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg19_1]
#   %add_4 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg5_1, %mm_3), kwargs = {})
#   %add_7 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_4, %mm_6), kwargs = {})
#   %add_12 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %mm_10), kwargs = {})
#   %convert_element_type_48 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_12, torch.float32), kwargs = {})
#   %pow_4 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_48, 2), kwargs = {})
#   %mean_3 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {})
#   %add_13 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {})
#   %rsqrt_3 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_13,), kwargs = {})
#   %mul_19 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_12, %rsqrt_3), kwargs = {})
#   %convert_element_type_49 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_19, torch.bfloat16), kwargs = {})
#   %mul_20 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_49, %arg19_1), kwargs = {})
#   return %buf48,%mul_20
# Two-pass Triton reduction kernel with xnumel = 1 and r0_numel = 2048; the
# reduction axis is walked in chunks of R0_BLOCK with r0_mask guarding the tail.
#
# Pass 1 (first loop): loads in_ptr0..in_ptr3 (bf16, upcast to fp32 right
#   after the load), sums them element-wise, squares the sum and accumulates
#   into _tmp10; tl.sum over axis 1 then yields the total sum of squares tmp10.
#   Loads use eviction_policy='evict_last'; the same addresses are read again
#   in pass 2.
# Pass 2 (second loop): reloads the four inputs (evict_first), recomputes
#   their sum, multiplies it by rsqrt(tmp10 / 2048.0 + 1e-05) and then by
#   in_ptr4, and stores the result to out_ptr1 ('*bf16' in the signature).
#
# All intermediate arithmetic is carried out in fp32: the .to(tl.float32)
# calls on tmp6 / tmp18 / tmp25 operate on values that are already fp32.
# The mean / rsqrt terms (tmp20..tmp24) depend only on tmp10 but are emitted
# inside the second loop body.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 9, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 28672}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp7 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp12 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp17 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp14 = tmp12 + tmp13
        tmp16 = tmp14 + tmp15
        tmp18 = tmp16 + tmp17
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tl.full([1, 1], 2048.0, tl.float32)
        tmp21 = (tmp10 / tmp20)
        tmp22 = tl.full([1, 1], 1e-05, tl.float32)
        tmp23 = tmp21 + tmp22
        tmp24 = libdevice.rsqrt(tmp23)
        tmp25 = tmp19 * tmp24
        tmp26 = tmp25.to(tl.float32)
        tmp28 = tmp26 * tmp27
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp28, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/yl/cylbdkvxyjduut5jxq7eculanrjmf6gukd3vra4hjml6srpeh6rp.py
# Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, hidden_states_6, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_12 => add_16
#   hidden_4 => convert_element_type_59
#   hidden_states_1 => add_4
#   hidden_states_10 => mul_23
#   hidden_states_4 => add_7
#   hidden_states_6 => add_12
#   hidden_states_9 => add_15
#   mul_18 => mul_22
#   pow_5 => pow_5
#   rsqrt_4 => rsqrt_4
#   to_16 => convert_element_type_58
#   variance_4 => mean_4
# Graph fragment:
#   %arg5_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg5_1]
#   %mm_3 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %mm_10 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_10]
#   %mm_13 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %add_15 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_15]
#   %buf55 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf55]
#   %arg23_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg23_1]
#   %add_4 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg5_1, %mm_3), kwargs = {})
#   %add_7 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_4, %mm_6), kwargs = {})
#   %add_12 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_7, %mm_10), kwargs = {})
#   %add_15 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_12, %mm_13), kwargs = {})
#   %convert_element_type_58 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_15, torch.float32), kwargs = {})
#   %pow_5 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_58, 2), kwargs = {})
#   %mean_4 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {})
#   %add_16 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {})
#   %rsqrt_4 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_16,), kwargs = {})
#   %mul_22 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_15, %rsqrt_4), kwargs = {})
#   %convert_element_type_59 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_22, torch.bfloat16), kwargs = {})
#   %mul_23 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_59, %arg23_1), kwargs = {})
#   return %add_15,%buf55,%mul_23
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_20 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_20', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_20', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 40960}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_20(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(r0_mask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp8, r0_mask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp14 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tl.full([1, 1], 2048.0, tl.float32)
        tmp17 = (tmp12 / tmp16)
        tmp18 = tl.full([1, 1], 1e-05, tl.float32)
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp24, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/27/c274p323hor2pyzvusxdyc4rdzpbp6yqux6rpu37b5r5vpquysnw.py
# Topologically Sorted Source Nodes: [cos, sin, view_7, key_states_7, k_2, chunk_5, setitem_4, mul_22, neg_5, cat_5, mul_23, k_embed_2, key_states_8], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_5 => cat_5
#   chunk_5 => split_5
#   cos => index
#   k_2 => convert_element_type_67
#   k_embed_2 => add_18
#   key_states_7 => permute_28
#   key_states_8 => convert_element_type_69
#   mul_22 => mul_26
#   mul_23 => mul_27
#   neg_5 => neg_7
#   setitem_4 => index_put_4, select_38, select_39, view_31
#   sin => index_1
#   view_7 => view_29
# Graph fragment:
#   %select_scatter_default_7 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_29 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_15, [1, 1, 2, 128]), kwargs = {})
#   %permute_28 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_29, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_67 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_28, torch.float32), kwargs = {})
#   %split_5 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_67, 64, -1), kwargs = {})
#   %select_38 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 0), kwargs = {})
#   %select_39 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_38, 0, 2), kwargs = {})
#   %mul_26 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_67, %index), kwargs = {})
#   %neg_7 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_11,), kwargs = {})
#   %cat_5 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_7, %getitem_10], -1), kwargs = {})
#   %mul_27 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_5, %index_1), kwargs = {})
#   %add_18 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_26, %mul_27), kwargs = {})
#   %convert_element_type_69 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_18, torch.bfloat16), kwargs = {})
#   %view_31 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_69, [2, 1, 128]), kwargs = {})
#   %index_put_4 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_39, [None, None, %arg1_1], %view_31), kwargs = {})
#   return %index_put_4
# Registers a 1-D pointwise Triton kernel with Inductor's AsyncCompile.
#
# What the kernel body does:
#   * Overrides the `xnumel` argument with the constant 2097152.
#   * Each program handles XBLOCK consecutive indices x0 and copies
#     in_ptr0[4194304 + x0] -> out_ptr0[x0]. 4194304 == 2 * 2097152, which
#     matches the select(..., 0, 2) step over the stride-2097152 dimension
#     listed in the graph fragment for this kernel.
#   * The loaded value is converted to float32 before the store; both pointers
#     are declared '*bf16' in the signature.
#   * Loads and stores pass mask=None; `xmask` is built but never used.
#
# NOTE(review): despite the "index_put"/"cat"/"mul" parts of the generated name,
# no indexed write or arithmetic appears in this kernel body -- only the copy.
# Presumably the scatter is done by a separate kernel; confirm at the call site.
#
# The kernel source below is a string literal consumed at runtime; do not edit
# it in place.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_21 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_21', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_21', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_21(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (4194304 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/d2/cd2a3cacahjbfsp3yuur5l43wiyjtn7r2ybnvex3osd5e3dd7daq.py
# Topologically Sorted Source Nodes: [setitem_5, view_8, value_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_5 => index_put_5, select_43, select_44, view_32
#   value_states_5 => permute_29
#   view_8 => view_30
# Graph fragment:
#   %buf60 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf60]
#   %select_scatter_default_7 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_int_4 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 0), kwargs = {})
#   %select_scatter_default_8 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_4, %index_put_4, 0, 2), kwargs = {})
#   %select_scatter_default_9 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_7, %select_scatter_default_8, 0, 0), kwargs = {})
#   %select_43 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_9, 0, 1), kwargs = {})
#   %select_44 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_43, 0, 2), kwargs = {})
#   %view_30 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_16, [1, 1, 2, 128]), kwargs = {})
#   %permute_29 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_30, [0, 2, 1, 3]), kwargs = {})
#   %view_32 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_29, [2, 1, 128]), kwargs = {})
#   %index_put_5 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_44, [None, None, %arg1_1], %view_32), kwargs = {})
#   return %index_put_5
# Registers a 1-D pointwise Triton kernel with Inductor's AsyncCompile.
#
# What the kernel body does (x0 ranges over XBLOCK consecutive indices per
# program; `xnumel` is overridden with the constant 2097152):
#   tmp5 = in_ptr0[x0]
#   tmp6 = in_ptr1[4194304 + x0]
#   tmp8 = in_ptr1[62914560 + x0]      # 62914560 == 58720256 + 2 * 2097152
#   tmp2 = (1 == 0)  -> always False
#   tmp4 = (2 == 2)  -> always True
#   tmp7 = where(tmp4, tmp5, tmp6)     # always tmp5
#   tmp9 = where(tmp2, tmp7, tmp8)     # always tmp8
#   out_ptr0[x0] = tmp9
#
# Because both predicates are compile-time constants, the stored value is
# always tmp8, i.e. the kernel is effectively a copy of the 2097152-element
# slab at offset 62914560 of in_ptr1. The tmp5/tmp6 loads do not influence the
# result (the metadata still reports num_load == 3).
#
# Loads and stores pass mask=None; `xmask` is built but never used.
#
# The kernel source below is a string literal consumed at runtime; do not edit
# it in place.
triton_poi_fused_index_put_select_transpose_view_22 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_22', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_22', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_22(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (4194304 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (62914560 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 2, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/th/cthatznbmyaifp4mrig267ep5lw2ozw7zstsdkpvysihzaafigps.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf63 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf63]
#   %buf60 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf60]
#   %select_scatter_default_7 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_int_4 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 0), kwargs = {})
#   %select_scatter_default_8 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_4, %index_put_4, 0, 2), kwargs = {})
#   %select_scatter_default_9 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_7, %select_scatter_default_8, 0, 0), kwargs = {})
#   %select_int_5 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_9, 0, 1), kwargs = {})
#   %select_scatter_default_10 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_5, %index_put_5, 0, 2), kwargs = {})
#   %select_scatter_default_11 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_9, %select_scatter_default_10, 0, 1), kwargs = {})
#   return %select_scatter_default_11
# Registers a 1-D pointwise Triton kernel with Inductor's AsyncCompile.
#
# Index decomposition inside the kernel (`xnumel` is overridden with
# 117440512 == 2 * 58720256):
#   x4 = xindex                           # flat output index
#   x2 = xindex // 58720256               # outermost index
#   x1 = (xindex // 2097152) % 28         # index in [0, 28)
#   x0 = xindex % 2097152                 # offset within one 2097152 slab
#   x3 = xindex % 58720256                # offset within one x2 half
#
# Selection logic, simplified (tmp8 is the constant comparison 1 == 0, so
# tmp13 always equals tmp12):
#   if x2 == 1:   out[x4] = in_ptr0[x0] if x1 == 2 else in_ptr2[58720256 + x3]
#   elif x2 == 0: out[x4] = in_ptr1[x0] if x1 == 2 else in_ptr2[x3]
#   else:         out[x4] = in_ptr2[x4]
# For xindex in [0, 117440512) x2 is 0 or 1, and in both cases the in_ptr2
# address equals x4. The net effect is a copy of in_ptr2 in which the x1 == 2
# slab of each half is replaced by in_ptr0 (x2 == 1) or in_ptr1 (x2 == 0).
#
# All five loads are issued unconditionally with mask=None; `xmask` is built
# but never used.
# NOTE(review): which buffers (buf63 / buf60 / select_scatter_default_7 in the
# graph fragment) bind to in_ptr0..in_ptr2 is decided at the call site, which
# is not visible here -- confirm there.
#
# The kernel source below is a string literal consumed at runtime; do not edit
# it in place.
triton_poi_fused_23 = async_compile.triton('triton_poi_fused_23', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_23', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_23(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 2, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ef/cefax5weyhzncviqs4mwgmmbovfa4spa4p4gbhzkmqwbby24vqub.py
# Topologically Sorted Source Nodes: [attn_output_8], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_8 => clone_4, convert_element_type_71, expand_12, mul_29, permute_30, select_48, select_49, unsqueeze_4, view_33
# Graph fragment:
#   %select_scatter_default_11 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_48 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 0), kwargs = {})
#   %select_49 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_48, 0, 2), kwargs = {})
#   %convert_element_type_71 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_49, torch.float32), kwargs = {})
#   %unsqueeze_4 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_71, 2), kwargs = {})
#   %expand_12 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_4, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_4 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_12,), kwargs = {memory_format: torch.contiguous_format})
#   %view_33 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_4, [1, 16, 8192, 128]), kwargs = {})
#   %permute_30 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_33, [0, 1, 3, 2]), kwargs = {})
#   %mul_29 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_30, 0.29730177875068026), kwargs = {})
#   return %expand_15
# Registers a 2-D tiled pointwise Triton kernel (grid_type 'Grid2D') with
# Inductor's AsyncCompile.
#
# What the kernel body does (`ynumel`/`xnumel` are overridden with 2048/8192):
#   * y comes from program_id(1), x from program_id(0); yindex is shaped
#     [YBLOCK, 1] and xindex [1, XBLOCK], so each program covers a
#     YBLOCK x XBLOCK tile.
#   * y0 = yindex % 128, y1 = yindex // 128 (so y1 is in [0, 16) for
#     yindex < 2048).
#   * Source address: in_ptr0[4194304 + y0 + 128*x2 + 1048576*(y1 // 8)].
#     `y1 // 8` takes only two values, so each 1048576-element source slab is
#     read for 8 consecutive y1 values (the expand-by-8 in the graph fragment).
#   * Destination address: out_ptr0[x2 + 8192*y3], i.e. x is the fastest-moving
#     output index while it is the stride-128 index on the input side, so the
#     kernel writes a transposed layout.
#   * The value is widened to float32, multiplied by 0.29730177875068026 and
#     stored; out_ptr0 is declared '*fp32' in the signature.
#   * Loads and stores pass mask=None; `xmask`/`ymask` are built but never used.
#
# NOTE(review): 0.29730177875068026 is numerically 128 ** -0.25; presumably an
# attention scale split between two operands -- confirm against the model code.
#
# The kernel source below is a string literal consumed at runtime; do not edit
# it in place.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_24 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_24', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_24', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_24(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (4194304 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/u6/cu6eeltfdrpwdfvtvvqjyjzmspt4ah33iwyw2rcneplbiowz3wug.py
# Topologically Sorted Source Nodes: [setitem_5, attn_output_8], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_8 => clone_5, convert_element_type_72, expand_13, unsqueeze_5
#   setitem_5 => select_46, select_47
# Graph fragment:
#   %select_scatter_default_11 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_46 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 1), kwargs = {})
#   %select_47 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_46, 0, 2), kwargs = {})
#   %convert_element_type_72 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_47, torch.float32), kwargs = {})
#   %unsqueeze_5 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_72, 2), kwargs = {})
#   %expand_13 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_5, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_5 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_13,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_5
# Registers a 1-D pointwise Triton kernel with Inductor's AsyncCompile.
#
# What the kernel body does (`xnumel` is overridden with 16777216):
#   * x3 = xindex is the flat output index.
#   * x0 = xindex % 1048576 is the offset within one 1048576-element slab.
#   * x2 = xindex // 8388608 (8388608 == 8 * 1048576) picks the source slab;
#     it is 0 or 1 for xindex < 16777216.
#   * out_ptr0[x3] = float32(in_ptr0[62914560 + x0 + 1048576*x2]).
#     Each source slab is therefore written 8 times back-to-back in the output
#     (the expand-by-8 followed by a contiguous clone in the graph fragment).
#   * in_ptr0 is declared '*bf16' and out_ptr0 '*fp32' in the signature, so the
#     kernel also widens the data.
#   * 62914560 == 58720256 + 2 * 2097152, matching the select(..., 0, 1) then
#     select(..., 0, 2) steps in the graph fragment.
#   * Loads and stores pass mask=None; `xmask` is built but never used.
#
# The kernel source below is a string literal consumed at runtime; do not edit
# it in place.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_25 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_25', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_25', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_25(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (62914560 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/xe/cxec2amtqtm426iprxikjgvb3rcumssb4oyniae4csudx25b6vhb.py
# Topologically Sorted Source Nodes: [cos, sin, view_10, key_states_10, k_3, chunk_7, setitem_6, mul_31, neg_7, cat_7, mul_32, k_embed_3, key_states_11], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_7 => cat_7
#   chunk_7 => split_7
#   cos => index
#   k_3 => convert_element_type_96
#   k_embed_3 => add_26
#   key_states_10 => permute_40
#   key_states_11 => convert_element_type_98
#   mul_31 => mul_37
#   mul_32 => mul_38
#   neg_7 => neg_10
#   setitem_6 => index_put_6, select_56, select_57, view_45
#   sin => index_1
#   view_10 => view_43
# Graph fragment:
#   %select_scatter_default_11 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_43 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_22, [1, 1, 2, 128]), kwargs = {})
#   %permute_40 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_43, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_96 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_40, torch.float32), kwargs = {})
#   %split_7 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_96, 64, -1), kwargs = {})
#   %select_56 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 0), kwargs = {})
#   %select_57 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_56, 0, 3), kwargs = {})
#   %mul_37 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_96, %index), kwargs = {})
#   %neg_10 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_15,), kwargs = {})
#   %cat_7 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_10, %getitem_14], -1), kwargs = {})
#   %mul_38 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_7, %index_1), kwargs = {})
#   %add_26 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_37, %mul_38), kwargs = {})
#   %convert_element_type_98 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_26, torch.bfloat16), kwargs = {})
#   %view_45 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_98, [2, 1, 128]), kwargs = {})
#   %index_put_6 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_57, [None, None, %arg1_1], %view_45), kwargs = {})
#   return %index_put_6
# NOTE(review): summary of what the kernel source below actually executes.
# - Pointwise 1-D copy of xnumel = 2097152 elements: element x0 is loaded from
#   in_ptr0 + 6291456 + x0 (6291456 == 3 * 2097152) and stored to
#   out_ptr0 + x0. Both pointers are declared '*bf16' in triton_meta.
# - The loaded value is widened to fp32 in registers and handed to tl.store
#   on a bf16 pointer, so no arithmetic is applied to the data.
# - The kernel contains exactly one load and one store ('num_load': 1,
#   'num_store': 1). The rotary arithmetic and the index_put named in the
#   graph fragment above are not performed here; presumably a later kernel
#   scatters into the copied buffer -- TODO confirm at the call site.
# - xmask is built all-True and never used; the load/store pass mask=None, so
#   there is no bounds check. This relies on the launch grid covering exactly
#   xnumel elements (assumes XBLOCK divides 2097152 -- confirm).
# - The xnumel argument is overwritten by the constant on the first line of
#   the kernel body, so the value passed by the caller is ignored.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_26 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_26', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_26', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_26(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (6291456 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ky/ckyihk7urn7vm7sgnbv7suukhaikl5vrevedw45yimjijuchmfby.py
# Topologically Sorted Source Nodes: [setitem_7, view_11, value_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_7 => index_put_7, select_61, select_62, view_46
#   value_states_7 => permute_41
#   view_11 => view_44
# Graph fragment:
#   %buf87 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf87]
#   %select_scatter_default_11 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_int_6 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 0), kwargs = {})
#   %select_scatter_default_12 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_6, %index_put_6, 0, 3), kwargs = {})
#   %select_scatter_default_13 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_11, %select_scatter_default_12, 0, 0), kwargs = {})
#   %select_61 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_13, 0, 1), kwargs = {})
#   %select_62 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_61, 0, 3), kwargs = {})
#   %view_44 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_23, [1, 1, 2, 128]), kwargs = {})
#   %permute_41 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_44, [0, 2, 1, 3]), kwargs = {})
#   %view_46 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_41, [2, 1, 128]), kwargs = {})
#   %index_put_7 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_62, [None, None, %arg1_1], %view_46), kwargs = {})
#   return %index_put_7
# NOTE(review): summary of what the kernel source below actually executes.
# - Pointwise 1-D kernel over xnumel = 2097152 elements; all pointers are
#   declared '*bf16' in triton_meta.
# - Three loads are issued per element:
#     tmp5 = in_ptr0[x0]
#     tmp6 = in_ptr1[6291456 + x0]     (6291456  == 3 * 2097152)
#     tmp8 = in_ptr1[65011712 + x0]    (65011712 == 58720256 + 3 * 2097152)
# - Both tl.where predicates compare constants: tmp2 is (1 == 0) and tmp4 is
#   (3 == 3). Hence tmp7 is always tmp5 and tmp9 is always tmp8, so the value
#   stored to out_ptr0[x0] is always in_ptr1[65011712 + x0]. The tmp5/tmp6
#   loads never reach the output; whether the compiler removes them is not
#   visible from here.
# - As with the sibling copy kernels, the index_put named in the graph
#   fragment is not performed by this kernel (one store, no indirect index);
#   presumably it is applied to the output buffer afterwards -- TODO confirm.
# - xmask is all-True and unused; loads/stores use mask=None, so the launch
#   grid must cover exactly xnumel elements (assumes XBLOCK divides it).
triton_poi_fused_index_put_select_transpose_view_27 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_27', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_27', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_27(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (6291456 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (65011712 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 3, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4x/c4xwyv2sf7dmk3snw7pz4iqx4bnmoj7cd47cerbmc7zktsxjivip.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf90 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf90]
#   %buf87 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf87]
#   %select_scatter_default_11 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_int_6 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 0), kwargs = {})
#   %select_scatter_default_12 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_6, %index_put_6, 0, 3), kwargs = {})
#   %select_scatter_default_13 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_11, %select_scatter_default_12, 0, 0), kwargs = {})
#   %select_int_7 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_13, 0, 1), kwargs = {})
#   %select_scatter_default_14 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_7, %index_put_7, 0, 3), kwargs = {})
#   %select_scatter_default_15 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_13, %select_scatter_default_14, 0, 1), kwargs = {})
#   return %select_scatter_default_15
# NOTE(review): summary of what the kernel source below actually executes.
# - Pointwise 1-D kernel over xnumel = 117440512 elements
#   (== 2 * 58720256 == 2 * 28 * 2097152); all pointers are '*bf16'.
# - The flat index is decomposed as
#     x2 = xindex // 58720256          outer slot (0 or 1 while xindex < xnumel)
#     x1 = (xindex // 2097152) % 28    slot inside x2
#     x0 = xindex % 2097152            offset inside one (x2, x1) slice
#     x3 = xindex % 58720256           offset inside one x2 slot
# - tmp8 is the constant comparison (1 == 0), so tmp13 is always tmp12.
#   With that folded in, the stored value is:
#     x2 == 1 and x1 == 3  ->  in_ptr0[x0]
#     x2 == 1 otherwise    ->  in_ptr2[58720256 + x3]   (same address as in_ptr2[xindex])
#     x2 == 0 and x1 == 3  ->  in_ptr1[x0]
#     x2 == 0 otherwise    ->  in_ptr2[x3]              (same address as in_ptr2[xindex])
#   i.e. out_ptr0 is a full copy of in_ptr2 with slice (x2=0, x1=3) replaced
#   by in_ptr1 and slice (x2=1, x1=3) replaced by in_ptr0.
# - Presumably in_ptr0/in_ptr1/in_ptr2 are buf90/buf87/select_scatter_default_11
#   in the order the graph fragment lists its placeholders -- TODO confirm
#   against the call site.
# - All five loads are issued for every element regardless of which branch is
#   selected, and nothing is masked (xmask is all-True and unused), so the
#   launch grid must cover exactly xnumel elements.
triton_poi_fused_28 = async_compile.triton('triton_poi_fused_28', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_28', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_28(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 3, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/xm/cxmi2jez3ce254rncfutx5x72uimatmqrzxq3zkmbvzhlhzhmm55.py
# Topologically Sorted Source Nodes: [attn_output_12], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_12 => clone_6, convert_element_type_100, expand_18, mul_40, permute_42, select_66, select_67, unsqueeze_6, view_47
# Graph fragment:
#   %select_scatter_default_15 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_66 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 0), kwargs = {})
#   %select_67 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_66, 0, 3), kwargs = {})
#   %convert_element_type_100 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_67, torch.float32), kwargs = {})
#   %unsqueeze_6 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_100, 2), kwargs = {})
#   %expand_18 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_6, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_6 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_18,), kwargs = {memory_format: torch.contiguous_format})
#   %view_47 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_6, [1, 16, 8192, 128]), kwargs = {})
#   %permute_42 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_47, [0, 1, 3, 2]), kwargs = {})
#   %mul_40 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_42, 0.29730177875068026), kwargs = {})
#   return %expand_21
# NOTE(review): summary of what the kernel source below actually executes.
# - 2-D tiled pointwise kernel (Grid2D): ynumel = 2048, xnumel = 8192;
#   program_id(1) walks y in YBLOCK steps, program_id(0) walks x in XBLOCK
#   steps. in_ptr0 is '*bf16', out_ptr0 is '*fp32'.
# - Index split: y0 = yindex % 128, y1 = yindex // 128 (0..15 while
#   yindex < ynumel), x2 = xindex.
# - Source address: in_ptr0 + 6291456 + y0 + 128 * x2 + 1048576 * (y1 // 8)
#   (6291456 == 3 * 2097152). y1 // 8 takes only the values 0 and 1, so each
#   group of 8 consecutive y1 values reads the same 1048576-element plane;
#   this is the expand/clone broadcast from the graph fragment.
# - Destination address: out_ptr0 + x2 + 8192 * yindex, so x is unit-stride
#   in the output while it has stride 128 in the input (a transposed write).
# - Each value is widened to fp32 (the second .to(tl.float32) is a no-op on
#   an fp32 value) and multiplied by 0.29730177875068026, which numerically
#   equals 128 ** -0.25. Presumably this is an attention scale factor
#   applied to this operand -- TODO confirm against the model code.
# - ymask/xmask are all-True and unused; loads/stores use mask=None, so the
#   launch grid must tile the 2048 x 8192 domain exactly.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_29 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_29', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_29', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_29(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (6291456 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/mo/cmo2n5y3sbd6okq46ycivt257w62mikhfyfpztopradoqs2cwmpp.py
# Topologically Sorted Source Nodes: [setitem_7, attn_output_12], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_12 => clone_7, convert_element_type_101, expand_19, unsqueeze_7
#   setitem_7 => select_64, select_65
# Graph fragment:
#   %select_scatter_default_15 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_64 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 1), kwargs = {})
#   %select_65 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_64, 0, 3), kwargs = {})
#   %convert_element_type_101 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_65, torch.float32), kwargs = {})
#   %unsqueeze_7 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_101, 2), kwargs = {})
#   %expand_19 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_7, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_7 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_19,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_7
# NOTE(review): summary of what the kernel source below actually executes.
# - Pointwise 1-D kernel over xnumel = 16777216 (== 2 * 8 * 1048576)
#   elements; in_ptr0 is '*bf16', out_ptr0 is '*fp32'.
# - Index split: x0 = xindex % 1048576 (offset inside one plane) and
#   x2 = xindex // 8388608 (0 or 1 while xindex < xnumel). The middle factor
#   of 8 does not appear in the source address, so every source plane is
#   written 8 times in a row: this is the expand/clone broadcast from the
#   graph fragment.
# - Source address: in_ptr0 + 65011712 + x0 + 1048576 * x2, where
#   65011712 == 58720256 + 3 * 2097152.
# - The value is widened to fp32 (the second .to(tl.float32) is a no-op) and
#   stored contiguously at out_ptr0 + xindex; no arithmetic is applied.
# - xmask is all-True and unused; loads/stores use mask=None, so the launch
#   grid must cover exactly xnumel elements (assumes XBLOCK divides it).
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_30 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_30', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_30', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_30(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (65011712 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/a6/ca6f66k3bbyuyolzrw4b6merrnrc4pyoj636ldz7eeure2o3b22q.py
# Topologically Sorted Source Nodes: [hidden_states_11, hidden_states_14, hidden_states_16, hidden_states_19, to_32, pow_9, variance_8, add_24, rsqrt_8, mul_36, hidden_8, hidden_states_20], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_24 => add_32
#   hidden_8 => convert_element_type_117
#   hidden_states_11 => add_20
#   hidden_states_14 => add_23
#   hidden_states_16 => add_28
#   hidden_states_19 => add_31
#   hidden_states_20 => mul_45
#   mul_36 => mul_44
#   pow_9 => pow_9
#   rsqrt_8 => rsqrt_8
#   to_32 => convert_element_type_116
#   variance_8 => mean_8
# Graph fragment:
#   %add_15 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_15]
#   %mm_17 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_17]
#   %mm_20 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_20]
#   %mm_24 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_24]
#   %mm_27 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_27]
#   %add_31 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_31]
#   %buf110 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf110]
#   %arg41_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg41_1]
#   %add_20 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_15, %mm_17), kwargs = {})
#   %add_23 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_20, %mm_20), kwargs = {})
#   %add_28 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_23, %mm_24), kwargs = {})
#   %add_31 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_28, %mm_27), kwargs = {})
#   %convert_element_type_116 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_31, torch.float32), kwargs = {})
#   %pow_9 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_116, 2), kwargs = {})
#   %mean_8 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_9, [-1], True), kwargs = {})
#   %add_32 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_8, 1e-05), kwargs = {})
#   %rsqrt_8 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_32,), kwargs = {})
#   %mul_44 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_31, %rsqrt_8), kwargs = {})
#   %convert_element_type_117 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_44, torch.bfloat16), kwargs = {})
#   %mul_45 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_117, %arg41_1), kwargs = {})
#   return %add_31,%buf110,%mul_45
# NOTE(review): summary of what the kernel source below actually executes.
# - Looped reduction kernel over a single row: xnumel is fixed to 1 and
#   r0_numel to 2048. Both loops step r0 by R0_BLOCK and guard every
#   load/store with r0_mask = r0_index < r0_numel. All pointers are '*bf16'.
# - Pass 1 (first loop):
#     * loads in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3 at the same
#       offset, widens each to fp32 and adds them (tmp8);
#     * stores tmp8 back to in_out_ptr0, i.e. the buffer is updated in place
#       (it is the entry in 'mutated_arg_names');
#     * accumulates tmp8 * tmp8 into _tmp12. The square uses the fp32 sum as
#       computed, not the value re-read from the bf16 buffer.
#     * tl.where(r0_mask, ...) keeps masked-off lanes from touching the
#       accumulator.
#   After the loop, tmp12 = tl.sum(_tmp12, 1) is the row's sum of squares.
# - Pass 2 (second loop):
#     * reloads the summed row from in_out_ptr0 and loads in_ptr4;
#     * computes rsqrt(tmp12 / 2048.0 + 1e-05), multiplies the row by it and
#       then by in_ptr4, and stores the result to out_ptr1.
#     * tmp16..tmp20 do not depend on r0 but are re-evaluated each iteration.
#     * tmp15/tmp22 are .to(tl.float32) on fp32 values, so no intermediate
#       bf16 rounding is applied between the normalisation and the multiply
#       by in_ptr4, even though the graph fragment names a bf16 conversion.
# - Presumably in_ptr4 is the weight arg41_1 from the fragment and
#   in_out_ptr0 carries the running residual -- TODO confirm at the call site.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 40960}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(r0_mask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp8, r0_mask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp14 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tl.full([1, 1], 2048.0, tl.float32)
        tmp17 = (tmp12 / tmp16)
        tmp18 = tl.full([1, 1], 1e-05, tl.float32)
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp24, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/pf/cpfgu4on7pkoupupyq6xw6zoyqv3jrfvo4f2u4wtqheges3klaww.py
# Topologically Sorted Source Nodes: [cos, sin, view_13, key_states_13, k_4, chunk_9, setitem_8, mul_40, neg_9, cat_9, mul_41, k_embed_4, key_states_14], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_9 => cat_9
#   chunk_9 => split_9
#   cos => index
#   k_4 => convert_element_type_125
#   k_embed_4 => add_34
#   key_states_13 => permute_52
#   key_states_14 => convert_element_type_127
#   mul_40 => mul_48
#   mul_41 => mul_49
#   neg_9 => neg_13
#   setitem_8 => index_put_8, select_74, select_75, view_59
#   sin => index_1
#   view_13 => view_57
# Graph fragment:
#   %select_scatter_default_15 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_57 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_29, [1, 1, 2, 128]), kwargs = {})
#   %permute_52 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_57, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_125 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_52, torch.float32), kwargs = {})
#   %split_9 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_125, 64, -1), kwargs = {})
#   %select_74 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 0), kwargs = {})
#   %select_75 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_74, 0, 4), kwargs = {})
#   %mul_48 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_125, %index), kwargs = {})
#   %neg_13 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_19,), kwargs = {})
#   %cat_9 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_13, %getitem_18], -1), kwargs = {})
#   %mul_49 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_9, %index_1), kwargs = {})
#   %add_34 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_48, %mul_49), kwargs = {})
#   %convert_element_type_127 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_34, torch.bfloat16), kwargs = {})
#   %view_59 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_127, [2, 1, 128]), kwargs = {})
#   %index_put_8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_75, [None, None, %arg1_1], %view_59), kwargs = {})
#   return %index_put_8
# Registers a 1-D pointwise Triton kernel with async_compile.triton. The kernel
# source is the string literal below; editing it changes what gets compiled.
#
# What the kernel body does:
#   out_ptr0[i] = in_ptr0[8388608 + i]
# i.e. a straight copy of a contiguous run of elements starting at element
# offset 8388608 of in_ptr0 into the start of out_ptr0.
#   * 8388608 == 4 * 2097152, which matches picking index 4 along a
#     stride-2097152 dimension (see the select_75 line in the graph fragment
#     above).
#   * The value is loaded with .to(tl.float32) and stored through out_ptr0,
#     which the triton_meta signature declares as '*bf16'.
#   * xnumel is overwritten with the constant 2097152 and xmask is built but
#     never used: tl.load / tl.store receive None as mask, so no bounds masking
#     is applied and the covered range is set by the launch grid.
#   * Despite the op names in the kernel identifier, this body contains no
#     indexed write. NOTE(review): presumably the index_put scatter itself is
#     issued by a separate kernel later in the file - confirm in the caller.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_32 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_32', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_32', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_32(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (8388608 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/rh/crhhtlzwq4rmlq6kyvmoj5dqfmxc2i6h47lkpykjpyawo4br2mse.py
# Topologically Sorted Source Nodes: [setitem_9, view_14, value_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_9 => index_put_9, select_79, select_80, view_60
#   value_states_9 => permute_53
#   view_14 => view_58
# Graph fragment:
#   %buf115 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf115]
#   %select_scatter_default_15 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_int_8 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 0), kwargs = {})
#   %select_scatter_default_16 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_8, %index_put_8, 0, 4), kwargs = {})
#   %select_scatter_default_17 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_15, %select_scatter_default_16, 0, 0), kwargs = {})
#   %select_79 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_17, 0, 1), kwargs = {})
#   %select_80 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_79, 0, 4), kwargs = {})
#   %view_58 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_30, [1, 1, 2, 128]), kwargs = {})
#   %permute_53 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_58, [0, 2, 1, 3]), kwargs = {})
#   %view_60 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_53, [2, 1, 128]), kwargs = {})
#   %index_put_9 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_80, [None, None, %arg1_1], %view_60), kwargs = {})
#   return %index_put_9
# Registers a 1-D pointwise Triton kernel with async_compile.triton. The kernel
# source is the string literal below; editing it changes what gets compiled.
#
# What the kernel body does:
#   * Issues three loads: in_ptr0[i], in_ptr1[8388608 + i] and
#     in_ptr1[67108864 + i].
#   * Both tl.where selectors are built from literals only:
#       tmp2 = (1 == 0)  -> always False
#       tmp4 = (4 == 4)  -> always True
#     so tmp7 == tmp5 and tmp9 == tmp8. The value stored is therefore
#       out_ptr0[i] = in_ptr1[67108864 + i]
#     and the other two loads do not influence the result.
#   * 67108864 == 1 * 58720256 + 4 * 2097152, which matches index 1 on the
#     stride-58720256 dimension followed by index 4 on the stride-2097152
#     dimension (select_79 / select_80 in the graph fragment above).
#   * xnumel is overwritten with 2097152 and xmask is built but unused; every
#     tl.load / tl.store passes None as mask, so no bounds masking is applied.
#   * NOTE(review): going by the PlaceHolder order in the graph fragment,
#     in_ptr0 is presumably buf115 and in_ptr1 select_scatter_default_15 -
#     confirm at the call site.
triton_poi_fused_index_put_select_transpose_view_33 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_33', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_33', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_33(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (8388608 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (67108864 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 4, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/yv/cyvjkfuswd7xr6oc7kcyvu3xul7eh4xh4kvx5v5bn2lz6uiccbw6.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf118 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf118]
#   %buf115 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf115]
#   %select_scatter_default_15 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_int_8 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 0), kwargs = {})
#   %select_scatter_default_16 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_8, %index_put_8, 0, 4), kwargs = {})
#   %select_scatter_default_17 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_15, %select_scatter_default_16, 0, 0), kwargs = {})
#   %select_int_9 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_17, 0, 1), kwargs = {})
#   %select_scatter_default_18 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_9, %index_put_9, 0, 4), kwargs = {})
#   %select_scatter_default_19 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_17, %select_scatter_default_18, 0, 1), kwargs = {})
#   return %select_scatter_default_19
# Registers a 1-D pointwise Triton kernel with async_compile.triton. The kernel
# source is the string literal below; editing it changes what gets compiled.
#
# Index decomposition of the flat index (xnumel is set to 117440512, which is
# 2 * 58720256 == 2 * 28 * 2097152):
#   x2 = xindex // 58720256          outer index
#   x1 = (xindex // 2097152) % 28    slab index within the outer index
#   x0 = xindex % 2097152            offset inside one slab
#   x3 = xindex % 58720256           offset inside one outer index
#   x4 = xindex                      flat offset (also the store offset)
#
# Value written to out_ptr0[x4], after folding the literal compare
# tmp8 = (1 == 0), which is always False so tmp13 == tmp12:
#   x2 == 1, x1 == 4 : in_ptr0[x0]
#   x2 == 1, x1 != 4 : in_ptr2[58720256 + x3]
#   x2 == 0, x1 == 4 : in_ptr1[x0]
#   x2 == 0, x1 != 4 : in_ptr2[x3]
#   otherwise        : in_ptr2[x4]
# This has the shape of the two nested select_scatter pairs listed in the
# graph fragment above (slab 4 replaced in outer index 0 and in outer index 1).
#
# Other points visible in the body:
#   * All five loads are issued unconditionally; the tl.where chain only
#     selects among the loaded values.
#   * xmask is built but unused; every tl.load / tl.store passes None as mask,
#     so no bounds masking is applied.
#   * NOTE(review): going by the PlaceHolder order in the graph fragment,
#     in_ptr0 / in_ptr1 / in_ptr2 are presumably buf118 / buf115 /
#     select_scatter_default_15 - confirm at the call site.
triton_poi_fused_34 = async_compile.triton('triton_poi_fused_34', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_34', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_34(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 4, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/wy/cwytn7aszvyd4hdekat6iwpkasbyos3ps2q72tdnl6hozlmb6kkc.py
# Topologically Sorted Source Nodes: [attn_output_16], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_16 => clone_8, convert_element_type_129, expand_24, mul_51, permute_54, select_84, select_85, unsqueeze_8, view_61
# Graph fragment:
#   %select_scatter_default_19 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_84 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 0), kwargs = {})
#   %select_85 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_84, 0, 4), kwargs = {})
#   %convert_element_type_129 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_85, torch.float32), kwargs = {})
#   %unsqueeze_8 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_129, 2), kwargs = {})
#   %expand_24 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_8, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_8 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_24,), kwargs = {memory_format: torch.contiguous_format})
#   %view_61 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_8, [1, 16, 8192, 128]), kwargs = {})
#   %permute_54 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_61, [0, 1, 3, 2]), kwargs = {})
#   %mul_51 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_54, 0.29730177875068026), kwargs = {})
#   return %expand_27
# Registers a 2-D tiled pointwise Triton kernel (Grid2D, TileHint.SQUARE) with
# async_compile.triton. The kernel source is the string literal below; editing
# it changes what gets compiled.
#
# What the kernel body does (ynumel is set to 2048, xnumel to 8192):
#   y0 = yindex % 128, y1 = yindex // 128, x2 = xindex
#   out_ptr0[x2 + 8192*yindex] =
#       float32(in_ptr0[8388608 + y0 + 128*x2 + 1048576*(y1 // 8)])
#       * 0.29730177875068026
#   * Read side: y0 is the unit-stride term and x2 has stride 128.
#     Write side: x2 is unit-stride and yindex has stride 8192. The element
#     order is therefore transposed between input and output.
#   * (y1 // 8) makes eight consecutive y1 values read the same 1048576-element
#     source slab, i.e. each source slab is repeated 8 times on the output
#     (cf. expand_24 / clone_8 in the graph fragment above).
#   * 8388608 == 4 * 2097152, matching index 4 on a stride-2097152 dimension
#     (select_85 in the graph fragment above).
#   * in_ptr0 is declared '*bf16' and out_ptr0 '*fp32' in the triton_meta
#     signature; the load is widened with .to(tl.float32) before the multiply.
#   * ymask / xmask are built but unused; tl.load / tl.store pass None as mask,
#     so no bounds masking is applied.
#   * NOTE(review): 0.29730177875068026 is numerically close to 128 ** -0.25;
#     presumably an attention scaling factor split between two operands -
#     confirm against the model code.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_35 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_35', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_35', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_35(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (8388608 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/cy/ccyf6u747lbyznwiijdddpbhwcihayxt2ms7uls3rs3utw2p2p2v.py
# Topologically Sorted Source Nodes: [setitem_9, attn_output_16], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_16 => clone_9, convert_element_type_130, expand_25, unsqueeze_9
#   setitem_9 => select_82, select_83
# Graph fragment:
#   %select_scatter_default_19 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_82 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 1), kwargs = {})
#   %select_83 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_82, 0, 4), kwargs = {})
#   %convert_element_type_130 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_83, torch.float32), kwargs = {})
#   %unsqueeze_9 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_130, 2), kwargs = {})
#   %expand_25 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_9, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_9 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_25,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_9
# Registers a 1-D pointwise Triton kernel with async_compile.triton. The kernel
# source is the string literal below; editing it changes what gets compiled.
#
# What the kernel body does (xnumel is set to 16777216 == 2 * 8 * 1048576):
#   x0 = xindex % 1048576      offset inside one 1048576-element slab
#   x2 = xindex // 8388608     which source slab to read
#   out_ptr0[xindex] = float32(in_ptr0[67108864 + x0 + 1048576*x2])
#   * Because x2 only changes every 8388608 == 8 * 1048576 output elements
#     while x0 wraps every 1048576, each source slab is written 8 times in a
#     row (cf. expand_25 / clone_9 in the graph fragment above).
#   * 67108864 == 1 * 58720256 + 4 * 2097152, matching index 1 on the
#     stride-58720256 dimension and index 4 on the stride-2097152 dimension
#     (select_82 / select_83 in the graph fragment above).
#   * in_ptr0 is declared '*bf16' and out_ptr0 '*fp32' in the triton_meta
#     signature, so this kernel also widens the data.
#   * xmask is built but unused; tl.load / tl.store pass None as mask, so no
#     bounds masking is applied.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_36 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_36', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_36', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_36(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (67108864 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/si/csizpfjonjqukbgsvw7p36pyjjeorsam3g4jsh27sd7rd5adq4gh.py
# Topologically Sorted Source Nodes: [cos, sin, view_16, key_states_16, k_5, chunk_11, setitem_10, mul_49, neg_11, cat_11, mul_50, k_embed_5, key_states_17], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_11 => cat_11
#   chunk_11 => split_11
#   cos => index
#   k_5 => convert_element_type_154
#   k_embed_5 => add_42
#   key_states_16 => permute_64
#   key_states_17 => convert_element_type_156
#   mul_49 => mul_59
#   mul_50 => mul_60
#   neg_11 => neg_16
#   setitem_10 => index_put_10, select_92, select_93, view_73
#   sin => index_1
#   view_16 => view_71
# Graph fragment:
#   %select_scatter_default_19 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_71 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_36, [1, 1, 2, 128]), kwargs = {})
#   %permute_64 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_71, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_154 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_64, torch.float32), kwargs = {})
#   %split_11 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_154, 64, -1), kwargs = {})
#   %select_92 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 0), kwargs = {})
#   %select_93 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_92, 0, 5), kwargs = {})
#   %mul_59 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_154, %index), kwargs = {})
#   %neg_16 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_23,), kwargs = {})
#   %cat_11 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_16, %getitem_22], -1), kwargs = {})
#   %mul_60 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_11, %index_1), kwargs = {})
#   %add_42 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_59, %mul_60), kwargs = {})
#   %convert_element_type_156 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_42, torch.bfloat16), kwargs = {})
#   %view_73 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_156, [2, 1, 128]), kwargs = {})
#   %index_put_10 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_93, [None, None, %arg1_1], %view_73), kwargs = {})
#   return %index_put_10
# Registers a 1-D pointwise Triton kernel with async_compile.triton. The kernel
# source is the string literal below; editing it changes what gets compiled.
#
# What the kernel body does:
#   out_ptr0[i] = in_ptr0[10485760 + i]
# i.e. a straight copy of a contiguous run of elements starting at element
# offset 10485760 of in_ptr0 into the start of out_ptr0.
#   * 10485760 == 5 * 2097152, which matches picking index 5 along a
#     stride-2097152 dimension (see the select_93 line in the graph fragment
#     above).
#   * The value is loaded with .to(tl.float32) and stored through out_ptr0,
#     which the triton_meta signature declares as '*bf16'.
#   * xnumel is overwritten with the constant 2097152 and xmask is built but
#     never used: tl.load / tl.store receive None as mask, so no bounds masking
#     is applied and the covered range is set by the launch grid.
#   * Despite the op names in the kernel identifier, this body contains no
#     indexed write. NOTE(review): presumably the index_put scatter itself is
#     issued by a separate kernel later in the file - confirm in the caller.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_37 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_37', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_37', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_37(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (10485760 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/w3/cw3525pxgipst4gvuvao4diuj4bswrtooanwvnba47dhy2qt2qwl.py
# Topologically Sorted Source Nodes: [setitem_11, view_17, value_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_11 => index_put_11, select_97, select_98, view_74
#   value_states_11 => permute_65
#   view_17 => view_72
# Graph fragment:
#   %buf142 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf142]
#   %select_scatter_default_19 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_int_10 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 0), kwargs = {})
#   %select_scatter_default_20 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_10, %index_put_10, 0, 5), kwargs = {})
#   %select_scatter_default_21 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_19, %select_scatter_default_20, 0, 0), kwargs = {})
#   %select_97 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_21, 0, 1), kwargs = {})
#   %select_98 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_97, 0, 5), kwargs = {})
#   %view_72 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_37, [1, 1, 2, 128]), kwargs = {})
#   %permute_65 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_72, [0, 2, 1, 3]), kwargs = {})
#   %view_74 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_65, [2, 1, 128]), kwargs = {})
#   %index_put_11 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_98, [None, None, %arg1_1], %view_74), kwargs = {})
#   return %index_put_11
# Reviewer notes for triton_poi_fused_index_put_select_transpose_view_38:
#   * The Triton source string below is handed to async_compile.triton together
#     with the kernel name and device_str='cuda'.
#   * 1D pointwise kernel over xnumel = 2097152 elements; every load/store uses
#     a mask of None, i.e. no bounds masking is applied.
#   * Both tl.where conditions are built from constants:
#       tmp2 = (1 == 0) -> always False,  tmp4 = (5 == 5) -> always True,
#     so tmp7 == tmp5 and tmp9 == tmp8. The stored value is therefore
#     in_ptr1[69206016 + x0]; tmp5 (in_ptr0) and tmp6 (in_ptr1 + 10485760) are
#     loaded but do not influence the result.
#   * 10485760 == 5 * 2097152 and 69206016 == 58720256 + 5 * 2097152, which is
#     consistent with the strides and the select indices (1, then 5) listed in
#     the graph fragment above.
#   * Values are widened to float32 after the load and written through a bf16
#     pointer (see 'signature' in triton_meta).
triton_poi_fused_index_put_select_transpose_view_38 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_38', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_38', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_38(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (10485760 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (69206016 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 5, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4d/c4d3p2zjmeigsicvnbwveoa45mb7i24ijsusny6idtgcrjp2bxxh.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf145 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf145]
#   %buf142 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf142]
#   %select_scatter_default_19 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_int_10 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 0), kwargs = {})
#   %select_scatter_default_20 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_10, %index_put_10, 0, 5), kwargs = {})
#   %select_scatter_default_21 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_19, %select_scatter_default_20, 0, 0), kwargs = {})
#   %select_int_11 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_21, 0, 1), kwargs = {})
#   %select_scatter_default_22 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_11, %index_put_11, 0, 5), kwargs = {})
#   %select_scatter_default_23 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_21, %select_scatter_default_22, 0, 1), kwargs = {})
#   return %select_scatter_default_23
# Reviewer notes for triton_poi_fused_39:
#   * The Triton source string below is handed to async_compile.triton together
#     with the kernel name and device_str='cuda'.
#   * 1D pointwise kernel over xnumel = 117440512 (== 2 * 58720256) elements,
#     with no bounds masking (all masks are None).
#   * Index decomposition of the flat output index:
#       x2 = xindex // 58720256          (outermost index)
#       x1 = (xindex // 2097152) % 28    (slot within the 28-entry dimension)
#       x0 = xindex % 2097152            (offset inside one slot)
#       x3 = xindex % 58720256           (offset inside one x2 slice)
#   * tmp8 = (1 == 0) is a constant False, so tmp13 always equals tmp12
#     (in_ptr2[58720256 + x3]).
#   * Net effect per element, following the tl.where chain:
#       x2 == 1:  in_ptr0[x0] if x1 == 5 else in_ptr2[58720256 + x3]
#       x2 == 0:  in_ptr1[x0] if x1 == 5 else in_ptr2[x3]
#       otherwise: in_ptr2[x4]
#     For x2 == 1, 58720256 + x3 == x4 and for x2 == 0, x3 == x4, so this is a
#     copy of in_ptr2 in which slot x1 == 5 is replaced by in_ptr1 (x2 == 0)
#     or in_ptr0 (x2 == 1).
#   * Which buffers are bound to in_ptr0/in_ptr1/in_ptr2 is decided at the call
#     site, which is not part of this block — confirm there if it matters.
triton_poi_fused_39 = async_compile.triton('triton_poi_fused_39', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_39', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_39(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 5, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/g7/cg7xz5acfes2hw7sesncc2ptmj2c5eb2moec33ukq3gyvqznnftf.py
# Topologically Sorted Source Nodes: [attn_output_20], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_20 => clone_10, convert_element_type_158, expand_30, mul_62, permute_66, select_102, select_103, unsqueeze_10, view_75
# Graph fragment:
#   %select_scatter_default_23 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_102 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 0), kwargs = {})
#   %select_103 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_102, 0, 5), kwargs = {})
#   %convert_element_type_158 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_103, torch.float32), kwargs = {})
#   %unsqueeze_10 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_158, 2), kwargs = {})
#   %expand_30 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_10, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_10 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_30,), kwargs = {memory_format: torch.contiguous_format})
#   %view_75 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_10, [1, 16, 8192, 128]), kwargs = {})
#   %permute_66 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_75, [0, 1, 3, 2]), kwargs = {})
#   %mul_62 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_66, 0.29730177875068026), kwargs = {})
#   return %expand_33
# Reviewer notes for
# triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_40:
#   * The Triton source string below is handed to async_compile.triton together
#     with the kernel name and device_str='cuda'.
#   * 2D pointwise kernel ('grid_type': 'Grid2D'): y comes from program_id(1)
#     with ynumel = 2048, x from program_id(0) with xnumel = 8192. No bounds
#     masking is applied (all masks are None).
#   * y is split as y0 = yindex % 128 and y1 = yindex // 128 (2048 == 16 * 128).
#   * Source address: in_ptr0[10485760 + y0 + 128*x2 + 1048576*(y1 // 8)].
#     Because of `y1 // 8`, eight consecutive y1 values read the same source
#     region, which matches the expand-to-8 step in the graph fragment above.
#     10485760 == 5 * 2097152.
#   * The bf16 input is widened to float32 (the second `.to(tl.float32)` on
#     tmp1 is a no-op cast of an already-float32 value), multiplied by
#     0.29730177875068026 and stored as fp32 to out_ptr0[x2 + 8192*y3].
#   * y0 is the unit-stride term on the load side while x2 is the unit-stride
#     term on the store side, so the kernel also transposes the last two
#     logical dimensions while copying.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_40 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_40', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_40', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_40(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (10485760 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/52/c52muwshie75l34bqk3yozmsrmyg3ma4xgg6ui4ynun2fclvutrl.py
# Topologically Sorted Source Nodes: [setitem_11, attn_output_20], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_20 => clone_11, convert_element_type_159, expand_31, unsqueeze_11
#   setitem_11 => select_100, select_101
# Graph fragment:
#   %select_scatter_default_23 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_100 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 1), kwargs = {})
#   %select_101 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_100, 0, 5), kwargs = {})
#   %convert_element_type_159 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_101, torch.float32), kwargs = {})
#   %unsqueeze_11 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_159, 2), kwargs = {})
#   %expand_31 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_11, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_11 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_31,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_11
# Reviewer notes for triton_poi_fused__to_copy_clone_expand_select_unsqueeze_41:
#   * The Triton source string below is handed to async_compile.triton together
#     with the kernel name and device_str='cuda'.
#   * 1D pointwise kernel over xnumel = 16777216 (== 16 * 1048576) output
#     elements, with no bounds masking (all masks are None).
#   * Index decomposition: x0 = xindex % 1048576, x2 = xindex // 8388608
#     (8388608 == 8 * 1048576).
#   * Source address: in_ptr0[69206016 + x0 + 1048576*x2]. The part of the
#     output index between the 1048576 and 8388608 strides does not appear in
#     the address, so each 1048576-element source region is written to eight
#     consecutive output regions (the expand-to-8 step in the graph fragment
#     above). 69206016 == 58720256 + 5 * 2097152.
#   * The bf16 input is widened to float32 (the second `.to(tl.float32)` is a
#     no-op cast) and stored as fp32 to out_ptr0[x3].
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_41 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_41', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_41', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_41(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (69206016 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/7b/c7b3cmc2jzcolduth4u4y4n2hk4xes3jslnbfnlz4rthnsfndet4.py
# Topologically Sorted Source Nodes: [cos, sin, view_19, key_states_19, k_6, chunk_13, setitem_12, mul_58, neg_13, cat_13, mul_59, k_embed_6, key_states_20], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_13 => cat_13
#   chunk_13 => split_13
#   cos => index
#   k_6 => convert_element_type_183
#   k_embed_6 => add_50
#   key_states_19 => permute_76
#   key_states_20 => convert_element_type_185
#   mul_58 => mul_70
#   mul_59 => mul_71
#   neg_13 => neg_19
#   setitem_12 => index_put_12, select_110, select_111, view_87
#   sin => index_1
#   view_19 => view_85
# Graph fragment:
#   %select_scatter_default_23 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_85 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_43, [1, 1, 2, 128]), kwargs = {})
#   %permute_76 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_85, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_183 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_76, torch.float32), kwargs = {})
#   %split_13 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_183, 64, -1), kwargs = {})
#   %select_110 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 0), kwargs = {})
#   %select_111 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_110, 0, 6), kwargs = {})
#   %mul_70 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_183, %index), kwargs = {})
#   %neg_19 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_27,), kwargs = {})
#   %cat_13 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_19, %getitem_26], -1), kwargs = {})
#   %mul_71 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_13, %index_1), kwargs = {})
#   %add_50 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_70, %mul_71), kwargs = {})
#   %convert_element_type_185 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_50, torch.bfloat16), kwargs = {})
#   %view_87 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_185, [2, 1, 128]), kwargs = {})
#   %index_put_12 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_111, [None, None, %arg1_1], %view_87), kwargs = {})
#   return %index_put_12
# Reviewer notes for
# triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_42:
#   * The Triton source string below is handed to async_compile.triton together
#     with the kernel name and device_str='cuda'.
#   * 1D pointwise kernel over xnumel = 2097152 elements, with no bounds
#     masking (all masks are None).
#   * The body is a plain copy: out_ptr0[x0] = in_ptr0[12582912 + x0], with
#     12582912 == 6 * 2097152. The value is widened to float32 after the load
#     and written through a bf16 pointer (see 'signature' in triton_meta).
#   * NOTE(review): the kernel name and header list mul/neg/cat/add/index_put,
#     but none of that arithmetic appears in this body ('num_load': 1,
#     'num_store': 1). Presumably those ops are applied by a different kernel
#     that is not part of this block — confirm at the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_42 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_42', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_42', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_42(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (12582912 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/3z/c3zczeyq2lbhsawjrwf3awotizydyfudrpllstmbcvabp43eoqn4.py
# Topologically Sorted Source Nodes: [setitem_13, view_20, value_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_13 => index_put_13, select_115, select_116, view_88
#   value_states_13 => permute_77
#   view_20 => view_86
# Graph fragment:
#   %buf170 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf170]
#   %select_scatter_default_23 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_int_12 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 0), kwargs = {})
#   %select_scatter_default_24 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_12, %index_put_12, 0, 6), kwargs = {})
#   %select_scatter_default_25 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_23, %select_scatter_default_24, 0, 0), kwargs = {})
#   %select_115 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_25, 0, 1), kwargs = {})
#   %select_116 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_115, 0, 6), kwargs = {})
#   %view_86 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_44, [1, 1, 2, 128]), kwargs = {})
#   %permute_77 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_86, [0, 2, 1, 3]), kwargs = {})
#   %view_88 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_77, [2, 1, 128]), kwargs = {})
#   %index_put_13 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_116, [None, None, %arg1_1], %view_88), kwargs = {})
#   return %index_put_13
# Reviewer notes for triton_poi_fused_index_put_select_transpose_view_43:
#   * The Triton source string below is handed to async_compile.triton together
#     with the kernel name and device_str='cuda'.
#   * 1D pointwise kernel over xnumel = 2097152 elements; every load/store uses
#     a mask of None, i.e. no bounds masking is applied.
#   * Both tl.where conditions are built from constants:
#       tmp2 = (1 == 0) -> always False,  tmp4 = (6 == 6) -> always True,
#     so tmp7 == tmp5 and tmp9 == tmp8. The stored value is therefore
#     in_ptr1[71303168 + x0]; tmp5 (in_ptr0) and tmp6 (in_ptr1 + 12582912) are
#     loaded but do not influence the result.
#   * 12582912 == 6 * 2097152 and 71303168 == 58720256 + 6 * 2097152, which is
#     consistent with the strides and the select indices (1, then 6) listed in
#     the graph fragment above.
#   * Values are widened to float32 after the load and written through a bf16
#     pointer (see 'signature' in triton_meta).
triton_poi_fused_index_put_select_transpose_view_43 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_43', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_43', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_43(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (12582912 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (71303168 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 6, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/nn/cnn3mwff4ilobazncywddwup23xpjbfdgq5lsuwlvs3un5krczei.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf173 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf173]
#   %buf170 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf170]
#   %select_scatter_default_23 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_int_12 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 0), kwargs = {})
#   %select_scatter_default_24 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_12, %index_put_12, 0, 6), kwargs = {})
#   %select_scatter_default_25 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_23, %select_scatter_default_24, 0, 0), kwargs = {})
#   %select_int_13 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_25, 0, 1), kwargs = {})
#   %select_scatter_default_26 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_13, %index_put_13, 0, 6), kwargs = {})
#   %select_scatter_default_27 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_25, %select_scatter_default_26, 0, 1), kwargs = {})
#   return %select_scatter_default_27
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused_44'. Everything inside the triple-quoted string is kernel
# source text and is intentionally left byte-identical.
#
# What the kernel body computes, per flat output index x4 (xnumel is
# reassigned to the constant 117440512 == 2 * 58720256 inside the kernel):
#   x2 = x4 // 58720256          outermost index (0 or 1 for x4 < xnumel)
#   x1 = (x4 // 2097152) % 28    index along the size-28 dimension
#   x0 = x4 % 2097152            offset inside one 2097152-element slice
#   x3 = x4 % 58720256           offset inside one outermost slice
# The value stored to out_ptr0[x4] is then:
#   x2 == 1 and x1 == 6 : in_ptr0[x0]
#   x2 == 1 otherwise   : in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 6 : in_ptr1[x0]
#   x2 == 0 otherwise   : in_ptr2[x3]
# tmp8 compares the constants 1 and 0, so it is always False and tmp13 always
# reduces to tmp12.
# Every load and the store pass None as the mask; xmask is built but unused.
# NOTE(review): in_ptr0 / in_ptr1 / in_ptr2 presumably correspond to buf173 /
# buf170 / select_scatter_default_23 from the graph fragment above, in that
# order -- confirm against the call site.
triton_poi_fused_44 = async_compile.triton('triton_poi_fused_44', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_44', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_44(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 6, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/tf/ctfbjfvs7oyhz4vht3kkmilhh2b2tu2pd5m7mvjqmf2y5dkcbzff.py
# Topologically Sorted Source Nodes: [attn_output_24], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_24 => clone_12, convert_element_type_187, expand_36, mul_73, permute_78, select_120, select_121, unsqueeze_12, view_89
# Graph fragment:
#   %select_scatter_default_27 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_120 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 0), kwargs = {})
#   %select_121 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_120, 0, 6), kwargs = {})
#   %convert_element_type_187 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_121, torch.float32), kwargs = {})
#   %unsqueeze_12 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_187, 2), kwargs = {})
#   %expand_36 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_12, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_12 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_36,), kwargs = {memory_format: torch.contiguous_format})
#   %view_89 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_12, [1, 16, 8192, 128]), kwargs = {})
#   %permute_78 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_89, [0, 1, 3, 2]), kwargs = {})
#   %mul_73 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_78, 0.29730177875068026), kwargs = {})
#   return %expand_39
# Registers the 2D pointwise Triton kernel below with async_compile. The
# triple-quoted string is kernel source text and is left byte-identical.
#
# Kernel body summary (ynumel and xnumel are reassigned to 2048 and 8192):
#   y0 = yindex % 128, y1 = yindex // 128, x2 = xindex
#   out_ptr0[x2 + 8192 * yindex] =
#       float32(in_ptr0[12582912 + y0 + 128 * x2 + 1048576 * (y1 // 8)])
#       * 0.29730177875068026
# The base offset 12582912 equals 6 * 2097152, which lines up with the
# select(..., 0, 0) followed by select(..., 0, 6) in the graph fragment above.
# Because of `y1 // 8`, eight consecutive y1 values read the same
# 1048576-element source block, while the x/y roles are swapped between the
# load (x2 scaled by 128) and the store (yindex scaled by 8192).
# The scale constant is numerically 128 ** -0.25.
# Loads and stores pass None as the mask; xmask / ymask are built but unused.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_45 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_45', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_45', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_45(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (12582912 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/fr/cfrekvnegwirca6j3m5tff4uyaqborl6hebnbj6sayhxtkao6dja.py
# Topologically Sorted Source Nodes: [setitem_13, attn_output_24], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_24 => clone_13, convert_element_type_188, expand_37, unsqueeze_13
#   setitem_13 => select_118, select_119
# Graph fragment:
#   %select_scatter_default_27 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_118 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 1), kwargs = {})
#   %select_119 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_118, 0, 6), kwargs = {})
#   %convert_element_type_188 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_119, torch.float32), kwargs = {})
#   %unsqueeze_13 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_188, 2), kwargs = {})
#   %expand_37 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_13, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_13 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_37,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_13
# Registers the 1D pointwise Triton kernel below with async_compile. The
# triple-quoted string is kernel source text and is left byte-identical.
#
# Kernel body summary (xnumel is reassigned to 16777216 == 2 * 8388608):
#   x0 = xindex % 1048576, x2 = xindex // 8388608
#   out_ptr0[xindex] = float32(in_ptr0[71303168 + x0 + 1048576 * x2])
# The base offset 71303168 equals 58720256 + 6 * 2097152, which lines up with
# select(..., 0, 1) followed by select(..., 0, 6) in the graph fragment above.
# The index component (xindex // 1048576) % 8 does not appear in the load
# address, so each 1048576-element source block is written out 8 times in a
# row; this matches the expand-to-8 followed by clone in the graph fragment.
# The load and store pass None as the mask; xmask is built but unused.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_46 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_46', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_46', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_46(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (71303168 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/7i/c7illzkvt3mkb42ydwg6gv7bginumvqfqpgwivmujjtipsuvtqrq.py
# Topologically Sorted Source Nodes: [cos, sin, view_22, key_states_22, k_7, chunk_15, setitem_14, mul_67, neg_15, cat_15, mul_68, k_embed_7, key_states_23], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_15 => cat_15
#   chunk_15 => split_15
#   cos => index
#   k_7 => convert_element_type_212
#   k_embed_7 => add_58
#   key_states_22 => permute_88
#   key_states_23 => convert_element_type_214
#   mul_67 => mul_81
#   mul_68 => mul_82
#   neg_15 => neg_22
#   setitem_14 => index_put_14, select_128, select_129, view_101
#   sin => index_1
#   view_22 => view_99
# Graph fragment:
#   %select_scatter_default_27 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_99 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_50, [1, 1, 2, 128]), kwargs = {})
#   %permute_88 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_99, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_212 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_88, torch.float32), kwargs = {})
#   %split_15 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_212, 64, -1), kwargs = {})
#   %select_128 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 0), kwargs = {})
#   %select_129 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_128, 0, 7), kwargs = {})
#   %mul_81 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_212, %index), kwargs = {})
#   %neg_22 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_31,), kwargs = {})
#   %cat_15 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_22, %getitem_30], -1), kwargs = {})
#   %mul_82 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_15, %index_1), kwargs = {})
#   %add_58 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_81, %mul_82), kwargs = {})
#   %convert_element_type_214 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_58, torch.bfloat16), kwargs = {})
#   %view_101 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_214, [2, 1, 128]), kwargs = {})
#   %index_put_14 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_129, [None, None, %arg1_1], %view_101), kwargs = {})
#   return %index_put_14
# Registers the 1D pointwise Triton kernel below with async_compile. The
# triple-quoted string is kernel source text and is left byte-identical.
#
# Kernel body summary (xnumel is reassigned to 2097152):
#   out_ptr0[x0] = in_ptr0[14680064 + x0]
# i.e. a straight copy of one 2097152-element slice. The base offset 14680064
# equals 7 * 2097152, which lines up with select(..., 0, 0) followed by
# select(..., 0, 7) in the graph fragment above.
# Despite the long fused-op name, the body contains no mul / neg / cat / add
# arithmetic and no indexed write.
# NOTE(review): the index_put itself presumably happens in a separate kernel
# or call that writes into this copy -- confirm against the call site.
# The load and store pass None as the mask; xmask is built but unused.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_47 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_47', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_47', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_47(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (14680064 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/bo/cborqri6ugf3sufqlffvfvwuiez36x6l7vgxamfxvjrat5l5xpm7.py
# Topologically Sorted Source Nodes: [setitem_15, view_23, value_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_15 => index_put_15, select_133, select_134, view_102
#   value_states_15 => permute_89
#   view_23 => view_100
# Graph fragment:
#   %buf197 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf197]
#   %select_scatter_default_27 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_int_14 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 0), kwargs = {})
#   %select_scatter_default_28 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_14, %index_put_14, 0, 7), kwargs = {})
#   %select_scatter_default_29 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_27, %select_scatter_default_28, 0, 0), kwargs = {})
#   %select_133 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_29, 0, 1), kwargs = {})
#   %select_134 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_133, 0, 7), kwargs = {})
#   %view_100 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_51, [1, 1, 2, 128]), kwargs = {})
#   %permute_89 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_100, [0, 2, 1, 3]), kwargs = {})
#   %view_102 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_89, [2, 1, 128]), kwargs = {})
#   %index_put_15 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_134, [None, None, %arg1_1], %view_102), kwargs = {})
#   return %index_put_15
# Registers the 1D pointwise Triton kernel below with async_compile. The
# triple-quoted string is kernel source text and is left byte-identical.
#
# Kernel body summary (xnumel is reassigned to 2097152):
#   tmp2 = (1 == 0)  -> always False
#   tmp4 = (7 == 7)  -> always True
#   tmp7 = where(tmp4, in_ptr0[x0], in_ptr1[14680064 + x0])
#   tmp9 = where(tmp2, tmp7, in_ptr1[73400320 + x0])
#   out_ptr0[x0] = tmp9
# Both conditions compare constants, so the stored value is always
# in_ptr1[73400320 + x0]; the in_ptr0 load and the first in_ptr1 load never
# reach the output. The offset 73400320 equals 58720256 + 7 * 2097152, which
# lines up with select(..., 0, 1) followed by select(..., 0, 7) in the graph
# fragment above.
# All loads and the store pass None as the mask; xmask is built but unused.
triton_poi_fused_index_put_select_transpose_view_48 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_48', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_48', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_48(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (14680064 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (73400320 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 7, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/d2/cd2vb43v7lz4utynw2cmlhexr2t5bsnznqe6bfrhc6wv4twl4kbz.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf200 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf200]
#   %buf197 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf197]
#   %select_scatter_default_27 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_int_14 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 0), kwargs = {})
#   %select_scatter_default_28 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_14, %index_put_14, 0, 7), kwargs = {})
#   %select_scatter_default_29 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_27, %select_scatter_default_28, 0, 0), kwargs = {})
#   %select_int_15 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_29, 0, 1), kwargs = {})
#   %select_scatter_default_30 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_15, %index_put_15, 0, 7), kwargs = {})
#   %select_scatter_default_31 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_29, %select_scatter_default_30, 0, 1), kwargs = {})
#   return %select_scatter_default_31
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused_49'. Everything inside the triple-quoted string is kernel
# source text and is intentionally left byte-identical.
#
# What the kernel body computes, per flat output index x4 (xnumel is
# reassigned to the constant 117440512 == 2 * 58720256 inside the kernel):
#   x2 = x4 // 58720256          outermost index (0 or 1 for x4 < xnumel)
#   x1 = (x4 // 2097152) % 28    index along the size-28 dimension
#   x0 = x4 % 2097152            offset inside one 2097152-element slice
#   x3 = x4 % 58720256           offset inside one outermost slice
# The value stored to out_ptr0[x4] is then:
#   x2 == 1 and x1 == 7 : in_ptr0[x0]
#   x2 == 1 otherwise   : in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 7 : in_ptr1[x0]
#   x2 == 0 otherwise   : in_ptr2[x3]
# tmp8 compares the constants 1 and 0, so it is always False and tmp13 always
# reduces to tmp12.
# Every load and the store pass None as the mask; xmask is built but unused.
# NOTE(review): in_ptr0 / in_ptr1 / in_ptr2 presumably correspond to buf200 /
# buf197 / select_scatter_default_27 from the graph fragment above, in that
# order -- confirm against the call site.
triton_poi_fused_49 = async_compile.triton('triton_poi_fused_49', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_49', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_49(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 7, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/uw/cuwxsilxxc4m7nekrntiyeubdqu3ybabek3fofgduqvouw7wmplv.py
# Topologically Sorted Source Nodes: [attn_output_28], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_28 => clone_14, convert_element_type_216, expand_42, mul_84, permute_90, select_138, select_139, unsqueeze_14, view_103
# Graph fragment:
#   %select_scatter_default_31 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_31]
#   %select_138 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_31, 0, 0), kwargs = {})
#   %select_139 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_138, 0, 7), kwargs = {})
#   %convert_element_type_216 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_139, torch.float32), kwargs = {})
#   %unsqueeze_14 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_216, 2), kwargs = {})
#   %expand_42 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_14, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_14 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_42,), kwargs = {memory_format: torch.contiguous_format})
#   %view_103 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_14, [1, 16, 8192, 128]), kwargs = {})
#   %permute_90 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_103, [0, 1, 3, 2]), kwargs = {})
#   %mul_84 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_90, 0.29730177875068026), kwargs = {})
#   return %expand_45
# Registers a 2D pointwise Triton kernel (Grid2D, y = 2048, x = 8192) with
# the async compiler. The kernel source is the string literal below and is
# compiled as-is, so it must not be edited casually.
#
# What the kernel body does:
#   * Splits the y index into y0 = yindex % 128 and y1 = yindex // 128.
#   * Reads a bf16 element from
#       in_ptr0 + 14680064 + y0 + 128*x2 + 1048576*(y1 // 8)
#     14680064 == 7 * 2097152, which matches the select(..., 0, 0) followed by
#     select(..., 0, 7) in the graph fragment above. The `y1 // 8` term makes
#     8 consecutive y1 values read the same 1048576-element source slab,
#     matching the expand to [1, 2, 8, 8192, 128] in the graph fragment.
#   * Converts the value to float32 (the second `.to(tl.float32)` is a no-op
#     repeat of the first cast) and multiplies by 0.29730177875068026.
#     NOTE(review): that constant is numerically 128 ** -0.25; presumably an
#     attention scaling factor -- confirm against the model code.
#   * Writes the fp32 result to out_ptr0 + x2 + 8192*y3, i.e. the output is
#     laid out with x as the fastest-varying index (a transposed write with
#     respect to the source, where y0 is the unit-stride index).
#
# ymask/xmask are built as all-True tensors but are not passed to
# tl.load/tl.store (both use mask=None), so no bounds masking is applied.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_50 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_50', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_50', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_50(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (14680064 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/cy/ccywlx2uc7zuj5m3bcvyrcnvx4hjr7d5l6vrahlrj2h3vpx7ew3i.py
# Topologically Sorted Source Nodes: [setitem_15, attn_output_28], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_28 => clone_15, convert_element_type_217, expand_43, unsqueeze_15
#   setitem_15 => select_136, select_137
# Graph fragment:
#   %select_scatter_default_31 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_31]
#   %select_136 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_31, 0, 1), kwargs = {})
#   %select_137 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_136, 0, 7), kwargs = {})
#   %convert_element_type_217 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_137, torch.float32), kwargs = {})
#   %unsqueeze_15 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_217, 2), kwargs = {})
#   %expand_43 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_15, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_15 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_43,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_15
# Registers a 1D pointwise Triton kernel (Grid1D, xnumel = 16777216) with the
# async compiler. The kernel source is the string literal below and is
# compiled as-is.
#
# What the kernel body does:
#   * Decomposes the flat output index into x0 = xindex % 1048576 and
#     x2 = xindex // 8388608. The index digit in between
#     ((xindex // 1048576) % 8) is not used for the read, so each
#     1048576-element source slab is replicated 8 times in the output,
#     matching the expand/clone to [1, 2, 8, 8192, 128] in the graph fragment
#     above.
#   * Reads a bf16 element from in_ptr0 + 73400320 + x0 + 1048576*x2.
#     73400320 == 58720256 + 7 * 2097152, which matches select(..., 0, 1)
#     followed by select(..., 0, 7) in the graph fragment.
#   * Converts it to float32 (the second `.to(tl.float32)` repeats the first
#     cast) and writes it contiguously to out_ptr0 + xindex.
#
# xmask is built as all-True but is not passed to tl.load/tl.store, so no
# bounds masking is applied.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_51 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_51', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_51', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_51(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (73400320 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/z6/cz6c3pptffcucxkth5csu2n2co5x2eedynq5rin27vwurfdobxtt.py
# Topologically Sorted Source Nodes: [cos, sin, view_25, key_states_25, k_8, chunk_17, setitem_16, mul_76, neg_17, cat_17, mul_77, k_embed_8, key_states_26], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_17 => cat_17
#   chunk_17 => split_17
#   cos => index
#   k_8 => convert_element_type_241
#   k_embed_8 => add_66
#   key_states_25 => permute_100
#   key_states_26 => convert_element_type_243
#   mul_76 => mul_92
#   mul_77 => mul_93
#   neg_17 => neg_25
#   setitem_16 => index_put_16, select_146, select_147, view_115
#   sin => index_1
#   view_25 => view_113
# Graph fragment:
#   %select_scatter_default_31 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_31]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_113 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_57, [1, 1, 2, 128]), kwargs = {})
#   %permute_100 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_113, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_241 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_100, torch.float32), kwargs = {})
#   %split_17 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_241, 64, -1), kwargs = {})
#   %select_146 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_31, 0, 0), kwargs = {})
#   %select_147 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_146, 0, 8), kwargs = {})
#   %mul_92 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_241, %index), kwargs = {})
#   %neg_25 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_35,), kwargs = {})
#   %cat_17 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_25, %getitem_34], -1), kwargs = {})
#   %mul_93 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_17, %index_1), kwargs = {})
#   %add_66 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_92, %mul_93), kwargs = {})
#   %convert_element_type_243 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_66, torch.bfloat16), kwargs = {})
#   %view_115 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_243, [2, 1, 128]), kwargs = {})
#   %index_put_16 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_147, [None, None, %arg1_1], %view_115), kwargs = {})
#   return %index_put_16
# Registers a 1D pointwise Triton kernel (Grid1D, xnumel = 2097152) with the
# async compiler. The kernel source is the string literal below and is
# compiled as-is.
#
# Despite the long fused-op name, the kernel body visible here is a plain
# slice copy: it reads in_ptr0[16777216 + x0] and writes the value to
# out_ptr0[x0] for every x0 in [0, 2097152). 16777216 == 8 * 2097152, which
# matches select(..., 0, 0) followed by select(..., 0, 8) in the graph
# fragment above.
#
# The loaded value is widened to float32 in registers; out_ptr0 is declared
# as '*bf16' in the signature, so the stored data is bf16 again.
# NOTE(review): no rotary-embedding arithmetic or scatter appears in this
# body; presumably the index_put itself is applied by a separate kernel that
# writes into the copied buffer -- confirm at the call site.
#
# xmask is built as all-True but is not passed to tl.load/tl.store.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_52 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_52', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_52', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_52(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (16777216 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/5f/c5fhbmpxq6lurphk6lt3s75x7yusdwri4egdubb5laiopogianje.py
# Topologically Sorted Source Nodes: [setitem_17, view_26, value_states_17], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_17 => index_put_17, select_151, select_152, view_116
#   value_states_17 => permute_101
#   view_26 => view_114
# Graph fragment:
#   %buf225 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf225]
#   %select_scatter_default_31 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_31]
#   %select_int_16 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_31, 0, 0), kwargs = {})
#   %select_scatter_default_32 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_16, %index_put_16, 0, 8), kwargs = {})
#   %select_scatter_default_33 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_31, %select_scatter_default_32, 0, 0), kwargs = {})
#   %select_151 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_33, 0, 1), kwargs = {})
#   %select_152 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_151, 0, 8), kwargs = {})
#   %view_114 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_58, [1, 1, 2, 128]), kwargs = {})
#   %permute_101 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_114, [0, 2, 1, 3]), kwargs = {})
#   %view_116 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_101, [2, 1, 128]), kwargs = {})
#   %index_put_17 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_152, [None, None, %arg1_1], %view_116), kwargs = {})
#   return %index_put_17
# Registers a 1D pointwise Triton kernel (Grid1D, xnumel = 2097152) with the
# async compiler. The kernel source is the string literal below and is
# compiled as-is.
#
# What the kernel body does:
#   * Loads three bf16 values per element:
#       tmp5 = in_ptr0[x0]
#       tmp6 = in_ptr1[16777216 + x0]   (16777216 == 8 * 2097152)
#       tmp8 = in_ptr1[75497472 + x0]   (75497472 == 58720256 + 8 * 2097152)
#   * Both predicates are comparisons of compile-time constants:
#       tmp2 = (1 == 0)  -> always False
#       tmp4 = (8 == 8)  -> always True
#     so tmp7 = where(True, tmp5, tmp6) == tmp5 and
#     tmp9 = where(False, tmp7, tmp8) == tmp8.
#   * Stores tmp9 to out_ptr0[x0].
#
# Net effect as written: out_ptr0[x0] = in_ptr1[75497472 + x0], i.e. a copy
# of the slice addressed by select(..., 0, 1) then select(..., 0, 8) in the
# graph fragment above.
# NOTE(review): tmp5 and tmp6 never influence the stored value; whether the
# Triton compiler eliminates those two loads is not visible from this file.
#
# xmask is built as all-True but is not passed to tl.load/tl.store.
triton_poi_fused_index_put_select_transpose_view_53 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_53', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_53', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_53(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (16777216 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (75497472 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 8, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/gb/cgbpjx3wu4l4rgs5r6ksyuoc5ye4zsjqakeodggohs2d6emz3phl.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf228 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf228]
#   %buf225 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf225]
#   %select_scatter_default_31 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_31]
#   %select_int_16 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_31, 0, 0), kwargs = {})
#   %select_scatter_default_32 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_16, %index_put_16, 0, 8), kwargs = {})
#   %select_scatter_default_33 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_31, %select_scatter_default_32, 0, 0), kwargs = {})
#   %select_int_17 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_33, 0, 1), kwargs = {})
#   %select_scatter_default_34 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_17, %index_put_17, 0, 8), kwargs = {})
#   %select_scatter_default_35 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_33, %select_scatter_default_34, 0, 1), kwargs = {})
#   return %select_scatter_default_35
# Registers a 1D pointwise Triton kernel (Grid1D, xnumel = 117440512, which
# is 2 * 58720256) with the async compiler. The kernel source is the string
# literal below and is compiled as-is.
#
# Index decomposition of the flat output index (x4 = xindex):
#   x2 = xindex // 58720256          outermost index
#   x1 = (xindex // 2097152) % 28    index within a 58720256-element half
#   x0 = xindex % 2097152            offset inside one 2097152-element slab
#   x3 = xindex % 58720256           offset inside one half
#
# Selection logic, after folding the constant predicate tmp8 = (1 == 0),
# which is always False (so tmp13 == tmp12):
#   x2 == 1 and x1 == 8  -> in_ptr0[x0]
#   x2 == 1 and x1 != 8  -> in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 8  -> in_ptr1[x0]
#   x2 == 0 and x1 != 8  -> in_ptr2[x3]
#   otherwise            -> in_ptr2[x4]
# The result is stored to out_ptr0[x4]. In other words the output is a copy
# of in_ptr2 with slab 8 of each half replaced by in_ptr1 (half 0) and
# in_ptr0 (half 1), which corresponds to the chained select_scatter ops in
# the graph fragment above.
# NOTE(review): presumably in_ptr0/in_ptr1/in_ptr2 are buf228/buf225/
# select_scatter_default_31 in the order the placeholders are listed above --
# confirm at the call site.
#
# All five loads are issued unconditionally (mask=None); xmask is built as
# all-True but is not passed to tl.load/tl.store.
triton_poi_fused_54 = async_compile.triton('triton_poi_fused_54', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_54', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_54(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 8, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4n/c4n4uixn7jjb4nn5hnjut2zwfu6vzklpdt2p4r5ay6qagcb4gpjb.py
# Topologically Sorted Source Nodes: [attn_output_32], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_32 => clone_16, convert_element_type_245, expand_48, mul_95, permute_102, select_156, select_157, unsqueeze_16, view_117
# Graph fragment:
#   %select_scatter_default_35 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_35]
#   %select_156 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_35, 0, 0), kwargs = {})
#   %select_157 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_156, 0, 8), kwargs = {})
#   %convert_element_type_245 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_157, torch.float32), kwargs = {})
#   %unsqueeze_16 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_245, 2), kwargs = {})
#   %expand_48 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_16, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_16 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_48,), kwargs = {memory_format: torch.contiguous_format})
#   %view_117 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_16, [1, 16, 8192, 128]), kwargs = {})
#   %permute_102 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_117, [0, 1, 3, 2]), kwargs = {})
#   %mul_95 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_102, 0.29730177875068026), kwargs = {})
#   return %expand_51
# Registers a 2D pointwise Triton kernel (Grid2D, y = 2048, x = 8192) with
# the async compiler. The kernel source is the string literal below and is
# compiled as-is, so it must not be edited casually.
#
# What the kernel body does:
#   * Splits the y index into y0 = yindex % 128 and y1 = yindex // 128.
#   * Reads a bf16 element from
#       in_ptr0 + 16777216 + y0 + 128*x2 + 1048576*(y1 // 8)
#     16777216 == 8 * 2097152, which matches the select(..., 0, 0) followed by
#     select(..., 0, 8) in the graph fragment above. The `y1 // 8` term makes
#     8 consecutive y1 values read the same 1048576-element source slab,
#     matching the expand to [1, 2, 8, 8192, 128] in the graph fragment.
#   * Converts the value to float32 (the second `.to(tl.float32)` is a no-op
#     repeat of the first cast) and multiplies by 0.29730177875068026.
#     NOTE(review): that constant is numerically 128 ** -0.25; presumably an
#     attention scaling factor -- confirm against the model code.
#   * Writes the fp32 result to out_ptr0 + x2 + 8192*y3, i.e. the output is
#     laid out with x as the fastest-varying index (a transposed write with
#     respect to the source, where y0 is the unit-stride index).
#
# ymask/xmask are built as all-True tensors but are not passed to
# tl.load/tl.store (both use mask=None), so no bounds masking is applied.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_55 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_55', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_55', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_55(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (16777216 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/e3/ce3w7dbgjrqztuq3yga3nhtnnk6shgkjaqcn4h3alrpvpsz6k7qq.py
# Topologically Sorted Source Nodes: [setitem_17, attn_output_32], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_32 => clone_17, convert_element_type_246, expand_49, unsqueeze_17
#   setitem_17 => select_154, select_155
# Graph fragment:
#   %select_scatter_default_35 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_35]
#   %select_154 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_35, 0, 1), kwargs = {})
#   %select_155 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_154, 0, 8), kwargs = {})
#   %convert_element_type_246 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_155, torch.float32), kwargs = {})
#   %unsqueeze_17 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_246, 2), kwargs = {})
#   %expand_49 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_17, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_17 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_49,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_17
# Registers the Triton source below with `async_compile` as a 1D pointwise
# kernel over 16777216 output elements.
#
# For flat output index x the kernel reads one bf16 value from
#   in_ptr0[75497472 + (x % 1048576) + 1048576 * (x // 8388608)]
# casts it to float32 and writes it to out_ptr0[x].
#
# The source index does not depend on (x // 1048576) % 8, so every
# 1048576-element source slab is repeated across 8 consecutive output slabs;
# this is how the expand-to-8 + contiguous clone in the graph fragment above
# is materialised.  The base offset equals 58720256 + 8 * 2097152, which is
# consistent with select(dim0, 1) followed by select(dim0, 8) under the
# strides listed in the graph fragment.
#
# `xmask` is computed but unused: both the load and the store pass `None` as
# the mask.  The second `.to(tl.float32)` (tmp1) is a no-op because the load
# result has already been converted.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_56 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_56', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_56', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_56(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (75497472 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/xa/cxak4klzo7iofklgmfyhqlucuyhajelqivkhr6fyj2rjrvh5n3zl.py
# Topologically Sorted Source Nodes: [cos, sin, view_28, key_states_28, k_9, chunk_19, setitem_18, mul_85, neg_19, cat_19, mul_86, k_embed_9, key_states_29], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_19 => cat_19
#   chunk_19 => split_19
#   cos => index
#   k_9 => convert_element_type_270
#   k_embed_9 => add_74
#   key_states_28 => permute_112
#   key_states_29 => convert_element_type_272
#   mul_85 => mul_103
#   mul_86 => mul_104
#   neg_19 => neg_28
#   setitem_18 => index_put_18, select_164, select_165, view_129
#   sin => index_1
#   view_28 => view_127
# Graph fragment:
#   %select_scatter_default_35 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_35]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_127 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_64, [1, 1, 2, 128]), kwargs = {})
#   %permute_112 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_127, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_270 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_112, torch.float32), kwargs = {})
#   %split_19 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_270, 64, -1), kwargs = {})
#   %select_164 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_35, 0, 0), kwargs = {})
#   %select_165 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_164, 0, 9), kwargs = {})
#   %mul_103 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_270, %index), kwargs = {})
#   %neg_28 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_39,), kwargs = {})
#   %cat_19 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_28, %getitem_38], -1), kwargs = {})
#   %mul_104 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_19, %index_1), kwargs = {})
#   %add_74 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_103, %mul_104), kwargs = {})
#   %convert_element_type_272 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_74, torch.bfloat16), kwargs = {})
#   %view_129 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_272, [2, 1, 128]), kwargs = {})
#   %index_put_18 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_165, [None, None, %arg1_1], %view_129), kwargs = {})
#   return %index_put_18
# Registers the Triton source below with `async_compile` as a 1D pointwise
# kernel over 2097152 elements.
#
# The kernel body is a straight copy: out_ptr0[x] = in_ptr0[18874368 + x].
# The base offset equals 9 * 2097152, which is consistent with
# select(dim0, 0) followed by select(dim0, 9) under the strides listed in the
# graph fragment above.
#
# The value is widened to float32 after the load and narrowed again by the
# store into the bf16 output pointer.  Only the copy is visible in this kernel
# body; none of the mul/neg/cat/add/index_put arithmetic named in the header
# is performed here (presumably it is emitted elsewhere -- verify against the
# call site).
#
# `xmask` is computed but unused: the load and the store both pass `None` as
# the mask.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_57 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_57', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_57', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_57(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (18874368 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/wj/cwjjnkvvpcrtficgf5vwgekr4zm4yrkrbaolabwktds45jo4ufxg.py
# Topologically Sorted Source Nodes: [setitem_19, view_29, value_states_19], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_19 => index_put_19, select_169, select_170, view_130
#   value_states_19 => permute_113
#   view_29 => view_128
# Graph fragment:
#   %buf252 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf252]
#   %select_scatter_default_35 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_35]
#   %select_int_18 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_35, 0, 0), kwargs = {})
#   %select_scatter_default_36 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_18, %index_put_18, 0, 9), kwargs = {})
#   %select_scatter_default_37 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_35, %select_scatter_default_36, 0, 0), kwargs = {})
#   %select_169 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_37, 0, 1), kwargs = {})
#   %select_170 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_169, 0, 9), kwargs = {})
#   %view_128 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_65, [1, 1, 2, 128]), kwargs = {})
#   %permute_113 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_128, [0, 2, 1, 3]), kwargs = {})
#   %view_130 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_113, [2, 1, 128]), kwargs = {})
#   %index_put_19 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_170, [None, None, %arg1_1], %view_130), kwargs = {})
#   return %index_put_19
# Registers the Triton source below with `async_compile` as a 1D pointwise
# kernel over 2097152 elements.
#
# The kernel evaluates two nested selects whose predicates are compile-time
# constants:
#   tmp4 = (9 == 9)  -> always true,  so tmp7 == tmp5 (in_ptr0[x])
#   tmp2 = (1 == 0)  -> always false, so tmp9 == tmp8
# The stored value is therefore always in_ptr1[77594624 + x]; the loads that
# produce tmp5 and tmp6 never influence the output.  The offset 77594624
# equals 58720256 + 9 * 2097152, which is consistent with select(dim0, 1)
# followed by select(dim0, 9) under the strides in the graph fragment above.
#
# NOTE(review): the constant predicates look like un-folded remnants of the
# select_scatter chain in the graph fragment; the generated source is kept
# exactly as emitted.
#
# `xmask` is computed but unused: every load and the store pass `None` as the
# mask.
triton_poi_fused_index_put_select_transpose_view_58 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_58', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_58', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_58(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (18874368 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (77594624 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 9, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/bl/cblsewtkq5755qzowvig46smba5f5hjesbud2at274aisaotarew.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf255 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf255]
#   %buf252 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf252]
#   %select_scatter_default_35 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_35]
#   %select_int_18 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_35, 0, 0), kwargs = {})
#   %select_scatter_default_36 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_18, %index_put_18, 0, 9), kwargs = {})
#   %select_scatter_default_37 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_35, %select_scatter_default_36, 0, 0), kwargs = {})
#   %select_int_19 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_37, 0, 1), kwargs = {})
#   %select_scatter_default_38 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_19, %index_put_19, 0, 9), kwargs = {})
#   %select_scatter_default_39 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_37, %select_scatter_default_38, 0, 1), kwargs = {})
#   return %select_scatter_default_39
# Registers the Triton source below with `async_compile` as a 1D pointwise
# kernel over 117440512 (= 2 * 58720256) elements.
#
# Index decomposition of the flat output index x4:
#   x2 = x4 // 58720256          (which 58720256-element half)
#   x1 = (x4 // 2097152) % 28    (which 2097152-element slab inside the half)
#   x0 = x4 % 2097152            (offset inside the slab)
#   x3 = x4 % 58720256           (offset inside the half)
#
# Effect of the nested `tl.where` chain, after noting that
# tmp8 = (1 == 0) is always false (so tmp13 == tmp12):
#   x2 == 1 and x1 == 9  -> in_ptr0[x0]
#   x2 == 1 and x1 != 9  -> in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 9  -> in_ptr1[x0]
#   x2 == 0 and x1 != 9  -> in_ptr2[x3]
#   any other x2         -> in_ptr2[x4]
# i.e. the output is a copy of in_ptr2 with slab 9 of half 0 replaced by
# in_ptr1 and slab 9 of half 1 replaced by in_ptr0, matching the chained
# select_scatter nodes in the graph fragment above.
#
# All five loads are issued unconditionally for every element; `xmask` is
# computed but unused because every load and the store pass `None` as the
# mask.
triton_poi_fused_59 = async_compile.triton('triton_poi_fused_59', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_59', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_59(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 9, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4l/c4lezbsduqrxzidwxsbza5up4yclsismblrwjvxpdnlsf2mkenfo.py
# Topologically Sorted Source Nodes: [attn_output_36], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_36 => clone_18, convert_element_type_274, expand_54, mul_106, permute_114, select_174, select_175, unsqueeze_18, view_131
# Graph fragment:
#   %select_scatter_default_39 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_39]
#   %select_174 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_39, 0, 0), kwargs = {})
#   %select_175 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_174, 0, 9), kwargs = {})
#   %convert_element_type_274 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_175, torch.float32), kwargs = {})
#   %unsqueeze_18 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_274, 2), kwargs = {})
#   %expand_54 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_18, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_18 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_54,), kwargs = {memory_format: torch.contiguous_format})
#   %view_131 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_18, [1, 16, 8192, 128]), kwargs = {})
#   %permute_114 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_131, [0, 1, 3, 2]), kwargs = {})
#   %mul_106 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_114, 0.29730177875068026), kwargs = {})
#   return %expand_57
# Registers the Triton source below with `async_compile` as a 2D-tiled
# pointwise kernel (ynumel = 2048, xnumel = 8192, Grid2D, square tile hint).
#
# Index decomposition:
#   x2 = xindex            (0 .. 8191)
#   y0 = yindex % 128
#   y1 = yindex // 128     (0 .. 15 for yindex < 2048)
#   y3 = yindex
#
# For every (y, x) pair the kernel loads one bf16 value from
#   in_ptr0[18874368 + y0 + 128 * x2 + 1048576 * (y1 // 8)]
# casts it to float32, multiplies it by 0.29730177875068026 and stores it at
#   out_ptr0[x2 + 8192 * y3].
# The input is read with x2 at stride 128 and y0 at stride 1 while the output
# is written with x2 at stride 1, so the kernel transposes those two axes.
# `y1 // 8` maps 8 consecutive y1 values onto the same source slab, which is
# the expand-to-8 broadcast from the graph fragment above.
#
# The base offset 18874368 equals 9 * 2097152, consistent with
# select(dim0, 0) followed by select(dim0, 9) under the listed strides.
# NOTE(review): the scale constant is numerically 128 ** -0.25; presumably an
# attention scaling factor for a head dimension of 128 -- confirm against the
# model source.
#
# `ymask` and `xmask` are computed but unused: the load and the store pass
# `None` as the mask.  The second `.to(tl.float32)` (tmp1) is a no-op.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_60 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_60', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_60', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_60(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (18874368 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ww/cwwzm4kb5436p6mzjpfhhnrossyy3cfdj43atdtz67jjebxl7smg.py
# Topologically Sorted Source Nodes: [setitem_19, attn_output_36], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_36 => clone_19, convert_element_type_275, expand_55, unsqueeze_19
#   setitem_19 => select_172, select_173
# Graph fragment:
#   %select_scatter_default_39 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_39]
#   %select_172 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_39, 0, 1), kwargs = {})
#   %select_173 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_172, 0, 9), kwargs = {})
#   %convert_element_type_275 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_173, torch.float32), kwargs = {})
#   %unsqueeze_19 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_275, 2), kwargs = {})
#   %expand_55 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_19, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_19 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_55,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_19
# Registers the Triton source below with `async_compile` as a 1D pointwise
# kernel over 16777216 output elements.
#
# For flat output index x the kernel reads one bf16 value from
#   in_ptr0[77594624 + (x % 1048576) + 1048576 * (x // 8388608)]
# casts it to float32 and writes it to out_ptr0[x].
#
# The source index does not depend on (x // 1048576) % 8, so every
# 1048576-element source slab is repeated across 8 consecutive output slabs;
# this is how the expand-to-8 + contiguous clone in the graph fragment above
# is materialised.  The base offset equals 58720256 + 9 * 2097152, which is
# consistent with select(dim0, 1) followed by select(dim0, 9) under the
# strides listed in the graph fragment.
#
# `xmask` is computed but unused: both the load and the store pass `None` as
# the mask.  The second `.to(tl.float32)` (tmp1) is a no-op because the load
# result has already been converted.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_61 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_61', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_61', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_61(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (77594624 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ro/cro2w2wezgoocnrzkpnay3b5jjnhionr5u3nmeemyw3qe2cfhu36.py
# Topologically Sorted Source Nodes: [cos, sin, view_31, key_states_31, k_10, chunk_21, setitem_20, mul_94, neg_21, cat_21, mul_95, k_embed_10, key_states_32], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_21 => cat_21
#   chunk_21 => split_21
#   cos => index
#   k_10 => convert_element_type_299
#   k_embed_10 => add_82
#   key_states_31 => permute_124
#   key_states_32 => convert_element_type_301
#   mul_94 => mul_114
#   mul_95 => mul_115
#   neg_21 => neg_31
#   setitem_20 => index_put_20, select_182, select_183, view_143
#   sin => index_1
#   view_31 => view_141
# Graph fragment:
#   %select_scatter_default_39 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_39]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_141 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_71, [1, 1, 2, 128]), kwargs = {})
#   %permute_124 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_141, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_299 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_124, torch.float32), kwargs = {})
#   %split_21 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_299, 64, -1), kwargs = {})
#   %select_182 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_39, 0, 0), kwargs = {})
#   %select_183 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_182, 0, 10), kwargs = {})
#   %mul_114 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_299, %index), kwargs = {})
#   %neg_31 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_43,), kwargs = {})
#   %cat_21 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_31, %getitem_42], -1), kwargs = {})
#   %mul_115 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_21, %index_1), kwargs = {})
#   %add_82 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_114, %mul_115), kwargs = {})
#   %convert_element_type_301 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_82, torch.bfloat16), kwargs = {})
#   %view_143 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_301, [2, 1, 128]), kwargs = {})
#   %index_put_20 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_183, [None, None, %arg1_1], %view_143), kwargs = {})
#   return %index_put_20
# Kernel summary (the Triton source below is passed to async_compile.triton as a
# string; do not edit the string by hand, it is compiler output):
#
#   out_ptr0[x0] = in_ptr0[20971520 + x0]
#
# - 1-D pointwise kernel ('Grid1D'); each program handles XBLOCK consecutive
#   indices starting at program_id(0) * XBLOCK.
# - The source offset 20971520 equals 10 * 2097152, i.e. the 11th slab of
#   2097152 elements in in_ptr0; the destination is written from offset 0.
# - The `xnumel` argument is overwritten with the constant 2097152 and is not
#   used for masking: `xmask` is built but both tl.load and tl.store receive
#   `None` as the mask. The range actually covered therefore depends on the
#   launch grid (presumably exactly xnumel elements -- confirm in the caller).
# - The loaded value is widened with .to(tl.float32) before the store; both
#   pointers are declared '*bf16' in the signature, so no arithmetic is done
#   on the value in between.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_62 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_62', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_62', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_62(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (20971520 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/hk/chk6nvrf5xn4wa43srdkmauy37wpnkodhxvvtk3tnuk7ogsgo6gt.py
# Topologically Sorted Source Nodes: [setitem_21, view_32, value_states_21], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_21 => index_put_21, select_187, select_188, view_144
#   value_states_21 => permute_125
#   view_32 => view_142
# Graph fragment:
#   %buf280 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf280]
#   %select_scatter_default_39 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_39]
#   %select_int_20 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_39, 0, 0), kwargs = {})
#   %select_scatter_default_40 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_20, %index_put_20, 0, 10), kwargs = {})
#   %select_scatter_default_41 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_39, %select_scatter_default_40, 0, 0), kwargs = {})
#   %select_187 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_41, 0, 1), kwargs = {})
#   %select_188 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_187, 0, 10), kwargs = {})
#   %view_142 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_72, [1, 1, 2, 128]), kwargs = {})
#   %permute_125 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_142, [0, 2, 1, 3]), kwargs = {})
#   %view_144 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_125, [2, 1, 128]), kwargs = {})
#   %index_put_21 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_188, [None, None, %arg1_1], %view_144), kwargs = {})
#   return %index_put_21
# Kernel summary (the Triton source below is passed to async_compile.triton as a
# string; do not edit the string by hand, it is compiler output):
#
#   out_ptr0[x0] = in_ptr1[79691776 + x0]          (effective result)
#
# - 1-D pointwise kernel ('Grid1D'); each program handles XBLOCK consecutive
#   indices. `xnumel` is overwritten with 2097152 and no mask is applied to
#   any load or store (`xmask` is built but unused).
# - Three values are loaded per element:
#     tmp5 = in_ptr0[x0]
#     tmp6 = in_ptr1[20971520 + x0]    (20971520 = 10 * 2097152)
#     tmp8 = in_ptr1[79691776 + x0]    (79691776 = 58720256 + 10 * 2097152)
# - Both predicates are built from constants only:
#     tmp2 = (1 == 0)   -> always False
#     tmp4 = (10 == 10) -> always True
#   so tmp7 = where(True, tmp5, tmp6) = tmp5 and
#   tmp9 = where(False, tmp7, tmp8) = tmp8. The stored value is therefore
#   always tmp8; tmp5 and tmp6 only feed the branch that is never selected.
# - NOTE(review): whether the Triton compiler removes the two unused loads is
#   not visible here -- confirm from the generated PTX if it matters.
triton_poi_fused_index_put_select_transpose_view_63 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_63', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_63', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_63(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (20971520 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (79691776 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 10, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/hf/chfh7eqktdiw6utbtvwbbybetgdrbbo2vhh35fq2rf2am7qv4fcj.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf283 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf283]
#   %buf280 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf280]
#   %select_scatter_default_39 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_39]
#   %select_int_20 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_39, 0, 0), kwargs = {})
#   %select_scatter_default_40 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_20, %index_put_20, 0, 10), kwargs = {})
#   %select_scatter_default_41 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_39, %select_scatter_default_40, 0, 0), kwargs = {})
#   %select_int_21 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_41, 0, 1), kwargs = {})
#   %select_scatter_default_42 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_21, %index_put_21, 0, 10), kwargs = {})
#   %select_scatter_default_43 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_41, %select_scatter_default_42, 0, 1), kwargs = {})
#   return %select_scatter_default_43
# Kernel summary (the Triton source below is passed to async_compile.triton as a
# string; do not edit the string by hand, it is compiler output):
#
# 1-D pointwise kernel ('Grid1D') over a flat index xindex, with `xnumel`
# overwritten to 117440512 (= 2 * 28 * 2097152). No mask is applied to any
# load or store (`xmask` is built but unused).
#
# Index decomposition of xindex (= x4):
#   x2 = xindex // 58720256          outermost index (58720256 = 28 * 2097152)
#   x1 = (xindex // 2097152) % 28    index of the 2097152-element slab
#   x0 = xindex % 2097152            offset inside a slab
#   x3 = xindex % 58720256           offset inside one x2 half
#
# Effective value stored at out_ptr0[x4], after resolving the nested tl.where
# calls (tmp8 = (1 == 0) is a constant False, so tmp13 is always tmp12):
#   x2 == 1 and x1 == 10  -> in_ptr0[x0]
#   x2 == 1 and x1 != 10  -> in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 10  -> in_ptr1[x0]
#   x2 == 0 and x1 != 10  -> in_ptr2[x3]
#   any other x2          -> in_ptr2[x4]
# With xindex < 117440512 only x2 in {0, 1} occurs, so the last row is not
# reached as long as the launch grid covers exactly xnumel elements
# (presumably so -- confirm in the caller).
#
# All five loads are issued unconditionally for every element; the four that
# are re-read across x1/x2 values carry eviction_policy='evict_last'.
triton_poi_fused_64 = async_compile.triton('triton_poi_fused_64', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_64', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_64(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 10, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/kv/ckvk44dbftmilp75l66qldwkipn57sc3rg3ic7zl2ksrd2mpmdwo.py
# Topologically Sorted Source Nodes: [attn_output_40], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_40 => clone_20, convert_element_type_303, expand_60, mul_117, permute_126, select_192, select_193, unsqueeze_20, view_145
# Graph fragment:
#   %select_scatter_default_43 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_43]
#   %select_192 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_43, 0, 0), kwargs = {})
#   %select_193 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_192, 0, 10), kwargs = {})
#   %convert_element_type_303 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_193, torch.float32), kwargs = {})
#   %unsqueeze_20 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_303, 2), kwargs = {})
#   %expand_60 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_20, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_20 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_60,), kwargs = {memory_format: torch.contiguous_format})
#   %view_145 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_20, [1, 16, 8192, 128]), kwargs = {})
#   %permute_126 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_145, [0, 1, 3, 2]), kwargs = {})
#   %mul_117 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_126, 0.29730177875068026), kwargs = {})
#   return %expand_63
# Kernel summary (the Triton source below is passed to async_compile.triton as a
# string; do not edit the string by hand, it is compiler output):
#
#   out_ptr0[x2 + 8192*y3] =
#       float32(in_ptr0[20971520 + y0 + 128*x2 + 1048576*(y1 // 8)])
#       * 0.29730177875068026
#
# - 2-D pointwise kernel ('Grid2D', TileHint.SQUARE): y is taken from
#   program_id(1) and x from program_id(0). `ynumel` and `xnumel` are
#   overwritten with 2048 and 8192; no mask is applied (`xmask` / `ymask` are
#   built but unused).
# - y3 = yindex is split into y0 = yindex % 128 and y1 = yindex // 128.
#   Because the source address uses (y1 // 8), eight consecutive y1 values
#   read the same 1048576-element source slab, i.e. the input is broadcast
#   8x along y1.
# - The source is read with stride 128 along x2 and stride 1 along y0, while
#   the destination is written with stride 1 along x2 and stride 8192 along
#   y3, so the kernel also swaps the two innermost axes.
# - Base offset 20971520 equals 10 * 2097152.
# - The input is declared '*bf16' and the output '*fp32'; the second
#   .to(tl.float32) on tmp0 repeats the conversion already done at load time.
# - NOTE(review): the scale constant is numerically 128 ** -0.25; its intended
#   meaning is not stated in this file -- confirm against the model source.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_65 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_65', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_65', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_65(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (20971520 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ol/colq2gz4uipshijzj5yk6y3qgw7cs74jybjaqiywbhcxf54l5x5b.py
# Topologically Sorted Source Nodes: [setitem_21, attn_output_40], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_40 => clone_21, convert_element_type_304, expand_61, unsqueeze_21
#   setitem_21 => select_190, select_191
# Graph fragment:
#   %select_scatter_default_43 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_43]
#   %select_190 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_43, 0, 1), kwargs = {})
#   %select_191 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_190, 0, 10), kwargs = {})
#   %convert_element_type_304 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_191, torch.float32), kwargs = {})
#   %unsqueeze_21 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_304, 2), kwargs = {})
#   %expand_61 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_21, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_21 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_61,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_21
# Kernel summary (the Triton source below is passed to async_compile.triton as a
# string; do not edit the string by hand, it is compiler output):
#
#   out_ptr0[x3] = float32(in_ptr0[79691776 + x0 + 1048576*x2])
#
# - 1-D pointwise kernel ('Grid1D'); `xnumel` is overwritten with 16777216
#   and no mask is applied (`xmask` is built but unused).
# - x3 = xindex, x0 = xindex % 1048576, x2 = xindex // 8388608
#   (8388608 = 8 * 1048576). The middle factor of 8 does not appear in the
#   source address, so each 1048576-element source slab is written out 8
#   times back to back, i.e. the input is broadcast 8x into a contiguous
#   output.
# - Base offset 79691776 equals 58720256 + 10 * 2097152.
# - The input is declared '*bf16' and the output '*fp32'; the second
#   .to(tl.float32) on tmp0 repeats the conversion already done at load time.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_66 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_66', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_66', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_66(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (79691776 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/i3/ci3f7hsg6peod3gtk7mq6msefntzq3dinoqyikqx6y72hcxc66j2.py
# Topologically Sorted Source Nodes: [cos, sin, view_34, key_states_34, k_11, chunk_23, setitem_22, mul_103, neg_23, cat_23, mul_104, k_embed_11, key_states_35], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_23 => cat_23
#   chunk_23 => split_23
#   cos => index
#   k_11 => convert_element_type_328
#   k_embed_11 => add_90
#   key_states_34 => permute_136
#   key_states_35 => convert_element_type_330
#   mul_103 => mul_125
#   mul_104 => mul_126
#   neg_23 => neg_34
#   setitem_22 => index_put_22, select_200, select_201, view_157
#   sin => index_1
#   view_34 => view_155
# Graph fragment:
#   %select_scatter_default_43 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_43]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_155 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_78, [1, 1, 2, 128]), kwargs = {})
#   %permute_136 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_155, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_328 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_136, torch.float32), kwargs = {})
#   %split_23 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_328, 64, -1), kwargs = {})
#   %select_200 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_43, 0, 0), kwargs = {})
#   %select_201 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_200, 0, 11), kwargs = {})
#   %mul_125 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_328, %index), kwargs = {})
#   %neg_34 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_47,), kwargs = {})
#   %cat_23 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_34, %getitem_46], -1), kwargs = {})
#   %mul_126 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_23, %index_1), kwargs = {})
#   %add_90 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_125, %mul_126), kwargs = {})
#   %convert_element_type_330 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_90, torch.bfloat16), kwargs = {})
#   %view_157 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_330, [2, 1, 128]), kwargs = {})
#   %index_put_22 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_201, [None, None, %arg1_1], %view_157), kwargs = {})
#   return %index_put_22
# Kernel summary (the Triton source below is passed to async_compile.triton as a
# string; do not edit the string by hand, it is compiler output):
#
#   out_ptr0[x0] = in_ptr0[23068672 + x0]
#
# - 1-D pointwise kernel ('Grid1D'); each program handles XBLOCK consecutive
#   indices starting at program_id(0) * XBLOCK.
# - The source offset 23068672 equals 11 * 2097152, i.e. the 12th slab of
#   2097152 elements in in_ptr0; the destination is written from offset 0.
# - The `xnumel` argument is overwritten with the constant 2097152 and is not
#   used for masking: `xmask` is built but both tl.load and tl.store receive
#   `None` as the mask. The range actually covered therefore depends on the
#   launch grid (presumably exactly xnumel elements -- confirm in the caller).
# - The loaded value is widened with .to(tl.float32) before the store; both
#   pointers are declared '*bf16' in the signature, so no arithmetic is done
#   on the value in between.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_67 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_67', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_67', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_67(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (23068672 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ur/cura4uypxyt4ukq4jh7ucjld5yrnhivxlz7udnq2ny66qc4ab2il.py
# Topologically Sorted Source Nodes: [setitem_23, view_35, value_states_23], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_23 => index_put_23, select_205, select_206, view_158
#   value_states_23 => permute_137
#   view_35 => view_156
# Graph fragment:
#   %buf307 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf307]
#   %select_scatter_default_43 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_43]
#   %select_int_22 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_43, 0, 0), kwargs = {})
#   %select_scatter_default_44 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_22, %index_put_22, 0, 11), kwargs = {})
#   %select_scatter_default_45 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_43, %select_scatter_default_44, 0, 0), kwargs = {})
#   %select_205 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_45, 0, 1), kwargs = {})
#   %select_206 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_205, 0, 11), kwargs = {})
#   %view_156 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_79, [1, 1, 2, 128]), kwargs = {})
#   %permute_137 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_156, [0, 2, 1, 3]), kwargs = {})
#   %view_158 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_137, [2, 1, 128]), kwargs = {})
#   %index_put_23 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_206, [None, None, %arg1_1], %view_158), kwargs = {})
#   return %index_put_23
# Registers a 1-D pointwise Triton kernel with async_compile; the kernel source is
# the triple-quoted string below. xnumel is hard-coded to 2097152 inside the body
# and every load/store is issued unmasked (mask argument is None).
#
# Both select conditions are built from constants: tmp2 is (1 == 0) and tmp4 is
# (11 == 11). Because tmp2 is always false, the final tl.where always yields tmp8,
# so each output element is in_ptr1[81788928 + x0]. The loads from in_ptr0 and from
# in_ptr1[23068672 + x0] only feed the branch that is never selected.
#
# NOTE(review): 23068672 == 11 * 2097152 and 81788928 == 58720256 + 11 * 2097152,
# which presumably address entries [0, 11] and [1, 11] of the buffer described in
# the graph fragment above — confirm against the call site.
triton_poi_fused_index_put_select_transpose_view_68 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_68', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_68', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_68(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (23068672 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (81788928 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 11, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/dg/cdgapvjg23ft77qwpuwjacyzpmzasugyiabxrytghkjtu5k6sz6j.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf310 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf310]
#   %buf307 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf307]
#   %select_scatter_default_43 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_43]
#   %select_int_22 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_43, 0, 0), kwargs = {})
#   %select_scatter_default_44 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_22, %index_put_22, 0, 11), kwargs = {})
#   %select_scatter_default_45 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_43, %select_scatter_default_44, 0, 0), kwargs = {})
#   %select_int_23 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_45, 0, 1), kwargs = {})
#   %select_scatter_default_46 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_23, %index_put_23, 0, 11), kwargs = {})
#   %select_scatter_default_47 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_45, %select_scatter_default_46, 0, 1), kwargs = {})
#   return %select_scatter_default_47
# Registers a 1-D pointwise Triton kernel with async_compile; the kernel source is
# the triple-quoted string below. xnumel is hard-coded to 117440512 (2 * 58720256)
# and every load/store is issued unmasked.
#
# The flat index is decomposed as:
#   x2 = xindex // 58720256          (outer index)
#   x1 = (xindex // 2097152) % 28    (index within the outer slice)
#   x0 = xindex % 2097152            (offset inside one 2097152-element chunk)
#   x3 = xindex % 58720256           (offset inside one outer slice)
# Value written to out_ptr0[xindex], after folding the constant tmp8 == (1 == 0):
#   x2 == 1 and x1 == 11 -> in_ptr0[x0]
#   x2 == 1 otherwise    -> in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 11 -> in_ptr1[x0]
#   x2 == 0 otherwise    -> in_ptr2[x3]
#   any other x2         -> in_ptr2[xindex]
# In other words, in_ptr2 is copied through except for chunk 11 of each outer
# slice, which is replaced by in_ptr1 (outer 0) or in_ptr0 (outer 1).
triton_poi_fused_69 = async_compile.triton('triton_poi_fused_69', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_69', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_69(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 11, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/b6/cb673czbnr3dm6pipmi34m6g3yjw5k3syfuqnfgjvsjyupzv76m3.py
# Topologically Sorted Source Nodes: [attn_output_44], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_44 => clone_22, convert_element_type_332, expand_66, mul_128, permute_138, select_210, select_211, unsqueeze_22, view_159
# Graph fragment:
#   %select_scatter_default_47 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_47]
#   %select_210 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_47, 0, 0), kwargs = {})
#   %select_211 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_210, 0, 11), kwargs = {})
#   %convert_element_type_332 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_211, torch.float32), kwargs = {})
#   %unsqueeze_22 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_332, 2), kwargs = {})
#   %expand_66 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_22, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_22 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_66,), kwargs = {memory_format: torch.contiguous_format})
#   %view_159 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_22, [1, 16, 8192, 128]), kwargs = {})
#   %permute_138 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_159, [0, 1, 3, 2]), kwargs = {})
#   %mul_128 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_138, 0.29730177875068026), kwargs = {})
#   return %expand_69
# Registers a 2-D tiled pointwise Triton kernel (Grid2D, TileHint.SQUARE) with
# async_compile; the kernel source is the triple-quoted string below. ynumel and
# xnumel are hard-coded to 2048 and 8192, and the load/store are unmasked.
#
# The y index is split into y0 = yindex % 128 and y1 = yindex // 128. For each
# (y, x) pair the kernel reads
#   in_ptr0[23068672 + y0 + 128 * x + 1048576 * (y1 // 8)]
# converts it to float32, multiplies by 0.29730177875068026 and stores the result
# at out_ptr0[x + 8192 * y] (out_ptr0 is declared '*fp32' in the signature).
# Because the source offset uses y1 // 8, eight consecutive y1 values read the
# same 1048576-element source chunk.
#
# NOTE(review): 23068672 == 11 * 2097152, presumably selecting entry [0, 11] of
# the buffer named in the graph fragment above — confirm against the call site.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_70 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_70', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_70', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_70(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (23068672 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/6u/c6uzc5t5on36qdesnfqcm2xncuhrfpmmwmi6qxymvaq6mj3pjglq.py
# Topologically Sorted Source Nodes: [setitem_23, attn_output_44], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_44 => clone_23, convert_element_type_333, expand_67, unsqueeze_23
#   setitem_23 => select_208, select_209
# Graph fragment:
#   %select_scatter_default_47 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_47]
#   %select_208 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_47, 0, 1), kwargs = {})
#   %select_209 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_208, 0, 11), kwargs = {})
#   %convert_element_type_333 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_209, torch.float32), kwargs = {})
#   %unsqueeze_23 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_333, 2), kwargs = {})
#   %expand_67 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_23, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_23 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_67,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_23
# Registers a 1-D pointwise Triton kernel with async_compile; the kernel source is
# the triple-quoted string below. xnumel is hard-coded to 16777216 and the
# load/store are unmasked.
#
# With x0 = xindex % 1048576 and x2 = xindex // 8388608, each output element is
#   out_ptr0[xindex] = float32(in_ptr0[81788928 + x0 + 1048576 * x2])
# (out_ptr0 is declared '*fp32' in the signature). Since x2 only advances every
# 8388608 elements while x0 wraps every 1048576, each 1048576-element source
# chunk is written out eight times in a row.
#
# NOTE(review): 81788928 == 58720256 + 11 * 2097152, presumably selecting entry
# [1, 11] of the buffer named in the graph fragment above — confirm at the caller.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_71 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_71', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_71', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_71(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (81788928 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/iv/civ3fnvcrqedwu6szi4itb6gfs6xfmhk2yuhqi4lcgykbjith5cg.py
# Topologically Sorted Source Nodes: [cos, sin, view_37, key_states_37, k_12, chunk_25, setitem_24, mul_112, neg_25, cat_25, mul_113, k_embed_12, key_states_38], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_25 => cat_25
#   chunk_25 => split_25
#   cos => index
#   k_12 => convert_element_type_357
#   k_embed_12 => add_98
#   key_states_37 => permute_148
#   key_states_38 => convert_element_type_359
#   mul_112 => mul_136
#   mul_113 => mul_137
#   neg_25 => neg_37
#   setitem_24 => index_put_24, select_218, select_219, view_171
#   sin => index_1
#   view_37 => view_169
# Graph fragment:
#   %select_scatter_default_47 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_47]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_169 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_85, [1, 1, 2, 128]), kwargs = {})
#   %permute_148 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_169, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_357 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_148, torch.float32), kwargs = {})
#   %split_25 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_357, 64, -1), kwargs = {})
#   %select_218 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_47, 0, 0), kwargs = {})
#   %select_219 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_218, 0, 12), kwargs = {})
#   %mul_136 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_357, %index), kwargs = {})
#   %neg_37 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_51,), kwargs = {})
#   %cat_25 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_37, %getitem_50], -1), kwargs = {})
#   %mul_137 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_25, %index_1), kwargs = {})
#   %add_98 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_136, %mul_137), kwargs = {})
#   %convert_element_type_359 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_98, torch.bfloat16), kwargs = {})
#   %view_171 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_359, [2, 1, 128]), kwargs = {})
#   %index_put_24 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_219, [None, None, %arg1_1], %view_171), kwargs = {})
#   return %index_put_24
# Registers a 1-D pointwise Triton kernel with async_compile; the kernel source is
# the triple-quoted string below. xnumel is hard-coded to 2097152 and the
# load/store are unmasked.
#
# Despite the long fused-op name, the kernel body is a plain offset copy:
#   out_ptr0[x0] = in_ptr0[25165824 + x0]
# for x0 in [0, 2097152).
#
# NOTE(review): 25165824 == 12 * 2097152, presumably selecting entry [0, 12] of
# the buffer named in the graph fragment above — confirm against the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_72 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_72', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_72', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_72(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (25165824 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/i2/ci2gmlx4k2bar26iofzll7grt2d5aygosn57qiou5fb6wrivxtar.py
# Topologically Sorted Source Nodes: [setitem_25, view_38, value_states_25], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_25 => index_put_25, select_223, select_224, view_172
#   value_states_25 => permute_149
#   view_38 => view_170
# Graph fragment:
#   %buf335 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf335]
#   %select_scatter_default_47 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_47]
#   %select_int_24 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_47, 0, 0), kwargs = {})
#   %select_scatter_default_48 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_24, %index_put_24, 0, 12), kwargs = {})
#   %select_scatter_default_49 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_47, %select_scatter_default_48, 0, 0), kwargs = {})
#   %select_223 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_49, 0, 1), kwargs = {})
#   %select_224 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_223, 0, 12), kwargs = {})
#   %view_170 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_86, [1, 1, 2, 128]), kwargs = {})
#   %permute_149 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_170, [0, 2, 1, 3]), kwargs = {})
#   %view_172 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_149, [2, 1, 128]), kwargs = {})
#   %index_put_25 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_224, [None, None, %arg1_1], %view_172), kwargs = {})
#   return %index_put_25
# Registers a 1-D pointwise Triton kernel with async_compile; the kernel source is
# the triple-quoted string below. xnumel is hard-coded to 2097152 inside the body
# and every load/store is issued unmasked (mask argument is None).
#
# Both select conditions are built from constants: tmp2 is (1 == 0) and tmp4 is
# (12 == 12). Because tmp2 is always false, the final tl.where always yields tmp8,
# so each output element is in_ptr1[83886080 + x0]. The loads from in_ptr0 and from
# in_ptr1[25165824 + x0] only feed the branch that is never selected.
#
# NOTE(review): 25165824 == 12 * 2097152 and 83886080 == 58720256 + 12 * 2097152,
# which presumably address entries [0, 12] and [1, 12] of the buffer described in
# the graph fragment above — confirm against the call site.
triton_poi_fused_index_put_select_transpose_view_73 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_73', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_73', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_73(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (25165824 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (83886080 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 12, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/oy/coybzfmkg4pmloudfju5mgkuynz7gvpxlnvwsmqxj62df72wauxj.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf338 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf338]
#   %buf335 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf335]
#   %select_scatter_default_47 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_47]
#   %select_int_24 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_47, 0, 0), kwargs = {})
#   %select_scatter_default_48 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_24, %index_put_24, 0, 12), kwargs = {})
#   %select_scatter_default_49 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_47, %select_scatter_default_48, 0, 0), kwargs = {})
#   %select_int_25 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_49, 0, 1), kwargs = {})
#   %select_scatter_default_50 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_25, %index_put_25, 0, 12), kwargs = {})
#   %select_scatter_default_51 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_49, %select_scatter_default_50, 0, 1), kwargs = {})
#   return %select_scatter_default_51
# Triton kernel source handed to async_compile.triton as a string.
# Flat 1D pointwise pass over 117440512 bf16 elements (2 * 28 * 2097152).
# Index decomposition used inside the kernel:
#   x2 = xindex // 58720256        -> outer plane, 0 or 1
#   x1 = (xindex // 2097152) % 28  -> slot within the plane
#   x0 = xindex % 2097152          -> offset within one slot
# Effective value stored at out_ptr0[xindex]:
#   plane 0, slot 12 -> in_ptr1[x0]
#   plane 1, slot 12 -> in_ptr0[x0]
#   every other slot -> in_ptr2[xindex] (plain copy)
# `tmp8` compares the constants 1 and 0, so it is always False and `tmp13`
# always equals `tmp12`.
# NOTE(review): going by the placeholder order in the graph fragment above,
# in_ptr0/in_ptr1/in_ptr2 presumably bind to buf338/buf335/
# select_scatter_default_47 -- confirm at the call site.
triton_poi_fused_74 = async_compile.triton('triton_poi_fused_74', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_74', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_74(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 12, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/yt/cytgo3ysk4dmi3cy6q7fnpe5swopson2bpbfznp7dn2cflvolltc.py
# Topologically Sorted Source Nodes: [attn_output_48], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_48 => clone_24, convert_element_type_361, expand_72, mul_139, permute_150, select_228, select_229, unsqueeze_24, view_173
# Graph fragment:
#   %select_scatter_default_51 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_51]
#   %select_228 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_51, 0, 0), kwargs = {})
#   %select_229 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_228, 0, 12), kwargs = {})
#   %convert_element_type_361 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_229, torch.float32), kwargs = {})
#   %unsqueeze_24 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_361, 2), kwargs = {})
#   %expand_72 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_24, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_24 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_72,), kwargs = {memory_format: torch.contiguous_format})
#   %view_173 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_24, [1, 16, 8192, 128]), kwargs = {})
#   %permute_150 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_173, [0, 1, 3, 2]), kwargs = {})
#   %mul_139 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_150, 0.29730177875068026), kwargs = {})
#   return %expand_75
# Triton kernel source handed to async_compile.triton as a string.
# 2D pointwise kernel over a 2048 (y) by 8192 (x) index space; reads bf16 and
# writes fp32:
#   y0 = yindex % 128, y1 = yindex // 128 (0..15), x2 = xindex
#   src = in_ptr0[25165824 + y0 + 128*x2 + 1048576*(y1 // 8)]
#   out_ptr0[x2 + 8192*yindex] = float32(src) * 0.29730177875068026
# 25165824 == 12 * 2097152, so the read starts at slot 12 of the first plane of
# the [2, 28, 1, 2, 8192, 128] buffer listed in the graph fragment above.
# `y1 // 8` maps eight consecutive y1 values onto the same 1048576-element
# source slice, which lines up with the expand-to-8 + clone nodes in the graph
# fragment. x is the stride-128 axis on the load and the unit-stride axis on
# the store, which realises the permute.
# `tmp1` re-casts a value that the load already converted to fp32.
# NOTE(review): the scale is numerically 128 ** -0.25 -- presumably derived
# from the trailing dimension of 128; confirm against the model code.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_75 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_75', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_75', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_75(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (25165824 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/o7/co7c6onejqsueo234tzb2mfekt7kdfc7pszt5xwg5wzov3oeexq4.py
# Topologically Sorted Source Nodes: [setitem_25, attn_output_48], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_48 => clone_25, convert_element_type_362, expand_73, unsqueeze_25
#   setitem_25 => select_226, select_227
# Graph fragment:
#   %select_scatter_default_51 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_51]
#   %select_226 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_51, 0, 1), kwargs = {})
#   %select_227 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_226, 0, 12), kwargs = {})
#   %convert_element_type_362 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_227, torch.float32), kwargs = {})
#   %unsqueeze_25 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_362, 2), kwargs = {})
#   %expand_73 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_25, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_25 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_73,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_25
# Triton kernel source handed to async_compile.triton as a string.
# Flat 1D pointwise kernel over 16777216 elements; reads bf16, writes fp32:
#   x0 = xindex % 1048576, x2 = xindex // 8388608
#   out_ptr0[xindex] = float32(in_ptr0[83886080 + x0 + 1048576*x2])
# 83886080 == 58720256 + 12 * 2097152, so the read starts at slot 12 of the
# second plane of the [2, 28, 1, 2, 8192, 128] buffer listed in the graph
# fragment above.
# The (xindex // 1048576) % 8 part of the output index does not appear in the
# load address, so each 1048576-element source slice is written out 8 times;
# this lines up with the expand-to-8 + contiguous clone in the graph fragment.
# `tmp1` re-casts a value that the load already converted to fp32.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_76 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_76', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_76', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_76(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (83886080 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ty/ctyntlbhee4l4o5o6ej24gblet6hcxvyb6oj4ft6ugxlc6gk5jpg.py
# Topologically Sorted Source Nodes: [cos, sin, view_40, key_states_40, k_13, chunk_27, setitem_26, mul_121, neg_27, cat_27, mul_122, k_embed_13, key_states_41], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_27 => cat_27
#   chunk_27 => split_27
#   cos => index
#   k_13 => convert_element_type_386
#   k_embed_13 => add_106
#   key_states_40 => permute_160
#   key_states_41 => convert_element_type_388
#   mul_121 => mul_147
#   mul_122 => mul_148
#   neg_27 => neg_40
#   setitem_26 => index_put_26, select_236, select_237, view_185
#   sin => index_1
#   view_40 => view_183
# Graph fragment:
#   %select_scatter_default_51 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_51]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_183 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_92, [1, 1, 2, 128]), kwargs = {})
#   %permute_160 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_183, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_386 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_160, torch.float32), kwargs = {})
#   %split_27 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_386, 64, -1), kwargs = {})
#   %select_236 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_51, 0, 0), kwargs = {})
#   %select_237 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_236, 0, 13), kwargs = {})
#   %mul_147 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_386, %index), kwargs = {})
#   %neg_40 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_55,), kwargs = {})
#   %cat_27 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_40, %getitem_54], -1), kwargs = {})
#   %mul_148 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_27, %index_1), kwargs = {})
#   %add_106 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_147, %mul_148), kwargs = {})
#   %convert_element_type_388 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_106, torch.bfloat16), kwargs = {})
#   %view_185 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_388, [2, 1, 128]), kwargs = {})
#   %index_put_26 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_237, [None, None, %arg1_1], %view_185), kwargs = {})
#   return %index_put_26
# Triton kernel source handed to async_compile.triton as a string.
# Flat 1D pointwise kernel over 2097152 elements, bf16 in and bf16 out:
#   out_ptr0[x0] = in_ptr0[27262976 + x0]
# 27262976 == 13 * 2097152, so this copies slot 13 of the first plane of the
# [2, 28, 1, 2, 8192, 128] buffer listed in the graph fragment above into a
# separate 2097152-element buffer.
# The loaded value is widened to fp32 and then stored through a bf16 pointer.
# NOTE(review): the mul/neg/cat/add/index_put nodes named in the header are not
# computed in this body, which only performs the copy -- presumably the scatter
# into the copied buffer happens in a separate kernel; confirm at the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_77 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_77', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_77', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_77(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (27262976 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/rn/crnejrk2cbewsb2cyse47ka7yul7k6fl275asuklod5ivnph7o5h.py
# Topologically Sorted Source Nodes: [setitem_27, view_41, value_states_27], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_27 => index_put_27, select_241, select_242, view_186
#   value_states_27 => permute_161
#   view_41 => view_184
# Graph fragment:
#   %buf362 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf362]
#   %select_scatter_default_51 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_51]
#   %select_int_26 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_51, 0, 0), kwargs = {})
#   %select_scatter_default_52 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_26, %index_put_26, 0, 13), kwargs = {})
#   %select_scatter_default_53 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_51, %select_scatter_default_52, 0, 0), kwargs = {})
#   %select_241 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_53, 0, 1), kwargs = {})
#   %select_242 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_241, 0, 13), kwargs = {})
#   %view_184 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_93, [1, 1, 2, 128]), kwargs = {})
#   %permute_161 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_184, [0, 2, 1, 3]), kwargs = {})
#   %view_186 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_161, [2, 1, 128]), kwargs = {})
#   %index_put_27 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_242, [None, None, %arg1_1], %view_186), kwargs = {})
#   return %index_put_27
# Triton kernel source handed to async_compile.triton as a string.
# Flat 1D pointwise kernel over 2097152 bf16 elements.
# Both predicates in the body compare constants:
#   tmp2 = (1 == 0)   -> always False
#   tmp4 = (13 == 13) -> always True
# so tmp7 == tmp5 and tmp9 == tmp8, and the stored value is always
#   out_ptr0[x0] = in_ptr1[85983232 + x0]
# 85983232 == 58720256 + 13 * 2097152, i.e. slot 13 of the second plane of the
# [2, 28, 1, 2, 8192, 128] buffer listed in the graph fragment above.
# The loads into tmp5 (in_ptr0) and tmp6 (in_ptr1 at plane 0, slot 13) do not
# influence the stored result.
# NOTE(review): as with the other kernels in this file, the index_put named in
# the header is not performed in this body -- presumably a separate kernel
# scatters into the copied buffer; confirm at the call site.
triton_poi_fused_index_put_select_transpose_view_78 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_78', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_78', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_78(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (27262976 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (85983232 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 13, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/uo/cuozpsivys2d2ltkwpcn4ffx7hnub5utrlxowflnevcwminrmg3u.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf365 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf365]
#   %buf362 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf362]
#   %select_scatter_default_51 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_51]
#   %select_int_26 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_51, 0, 0), kwargs = {})
#   %select_scatter_default_52 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_26, %index_put_26, 0, 13), kwargs = {})
#   %select_scatter_default_53 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_51, %select_scatter_default_52, 0, 0), kwargs = {})
#   %select_int_27 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_53, 0, 1), kwargs = {})
#   %select_scatter_default_54 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_27, %index_put_27, 0, 13), kwargs = {})
#   %select_scatter_default_55 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_53, %select_scatter_default_54, 0, 1), kwargs = {})
#   return %select_scatter_default_55
# Triton kernel source handed to async_compile.triton as a string.
# Flat 1D pointwise pass over 117440512 bf16 elements (2 * 28 * 2097152).
# Index decomposition used inside the kernel:
#   x2 = xindex // 58720256        -> outer plane, 0 or 1
#   x1 = (xindex // 2097152) % 28  -> slot within the plane
#   x0 = xindex % 2097152          -> offset within one slot
# Effective value stored at out_ptr0[xindex]:
#   plane 0, slot 13 -> in_ptr1[x0]
#   plane 1, slot 13 -> in_ptr0[x0]
#   every other slot -> in_ptr2[xindex] (plain copy)
# `tmp8` compares the constants 1 and 0, so it is always False and `tmp13`
# always equals `tmp12`.
# NOTE(review): going by the placeholder order in the graph fragment above,
# in_ptr0/in_ptr1/in_ptr2 presumably bind to buf365/buf362/
# select_scatter_default_51 -- confirm at the call site.
triton_poi_fused_79 = async_compile.triton('triton_poi_fused_79', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_79', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_79(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 13, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/vr/cvrb235uqbfygmld3j6tzjcedub2jjpnpootlo4hv7naujngpj65.py
# Topologically Sorted Source Nodes: [attn_output_52], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_52 => clone_26, convert_element_type_390, expand_78, mul_150, permute_162, select_246, select_247, unsqueeze_26, view_187
# Graph fragment:
#   %select_scatter_default_55 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_55]
#   %select_246 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_55, 0, 0), kwargs = {})
#   %select_247 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_246, 0, 13), kwargs = {})
#   %convert_element_type_390 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_247, torch.float32), kwargs = {})
#   %unsqueeze_26 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_390, 2), kwargs = {})
#   %expand_78 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_26, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_26 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_78,), kwargs = {memory_format: torch.contiguous_format})
#   %view_187 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_26, [1, 16, 8192, 128]), kwargs = {})
#   %permute_162 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_187, [0, 1, 3, 2]), kwargs = {})
#   %mul_150 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_162, 0.29730177875068026), kwargs = {})
#   return %expand_81
# Pointwise Triton kernel over a 2048 (y) x 8192 (x) index space, launched on a
# 2D grid (program_id(1) tiles y, program_id(0) tiles x).
#
# For every (y, x) pair the kernel:
#   * splits y into y0 = y % 128 and y1 = y // 128,
#   * loads one bf16 element from
#       in_ptr0 + 27262976 + y0 + 128*x + 1048576*(y1 // 8)
#     and upcasts it to fp32,
#   * multiplies it by the constant 0.29730177875068026,
#   * stores the fp32 result to out_ptr0 + x + 8192*y.
#
# 27262976 == 13 * 2097152, which matches the index-13 select along the
# stride-2097152 dimension shown in the graph fragment comment above.  The
# `y1 // 8` term makes eight consecutive y1 values read the same
# 1048576-element slab, i.e. the expand-by-8 from the graph fragment is done
# by re-reading rather than by materialising copies.  The read is strided by
# 128 along x while the write is contiguous along x, so the transpose is
# folded into the addressing.
#
# ymask/xmask are built as all-True and every load/store passes None as its
# mask, so no bounds masking is applied inside the kernel.
# NOTE(review): the scale is numerically 128 ** -0.25; presumably one half of
# an attention scaling factor applied to this operand - confirm against the
# model code.
# The kernel source below is a string literal passed to async_compile.triton;
# its text must stay unchanged because it is what gets compiled.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_80 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_80', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_80', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_80(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (27262976 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/z3/cz35fdhyfctzxqh3ps63ufw5hai2f6xpbtj5j7n73hf2jd4rbu3s.py
# Topologically Sorted Source Nodes: [setitem_27, attn_output_52], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_52 => clone_27, convert_element_type_391, expand_79, unsqueeze_27
#   setitem_27 => select_244, select_245
# Graph fragment:
#   %select_scatter_default_55 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_55]
#   %select_244 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_55, 0, 1), kwargs = {})
#   %select_245 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_244, 0, 13), kwargs = {})
#   %convert_element_type_391 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_245, torch.float32), kwargs = {})
#   %unsqueeze_27 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_391, 2), kwargs = {})
#   %expand_79 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_27, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_27 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_79,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_27
# Pointwise Triton kernel over a flat index space of 16777216 elements
# (1D grid, program_id(0) tiles x).
#
# For every flat index the kernel:
#   * computes x0 = index % 1048576 and x2 = index // 8388608,
#   * loads one bf16 element from in_ptr0 + 85983232 + x0 + 1048576*x2 and
#     upcasts it to fp32,
#   * stores the fp32 value to out_ptr0 + index.
#
# 85983232 == 58720256 + 13 * 2097152, which matches selecting index 1 on the
# stride-58720256 dimension and index 13 on the stride-2097152 dimension in
# the graph fragment comment above.  Because the source address ignores the
# (index // 1048576) % 8 component, each 1048576-element slab is written out
# eight times per x2 value; this is the expand + contiguous clone from the
# graph fragment.
#
# xmask is built as all-True and the load/store pass None as their mask, so
# no bounds masking is applied inside the kernel.
# The kernel source below is a string literal passed to async_compile.triton;
# its text must stay unchanged because it is what gets compiled.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_81 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_81', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_81', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_81(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (85983232 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/tc/ctcuudc7f53u6isolyi54iittvtvmjpppd7mhjtmxjnq354i2vbt.py
# Topologically Sorted Source Nodes: [cos, sin, view_43, key_states_43, k_14, chunk_29, setitem_28, mul_130, neg_29, cat_29, mul_131, k_embed_14, key_states_44], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_29 => cat_29
#   chunk_29 => split_29
#   cos => index
#   k_14 => convert_element_type_415
#   k_embed_14 => add_114
#   key_states_43 => permute_172
#   key_states_44 => convert_element_type_417
#   mul_130 => mul_158
#   mul_131 => mul_159
#   neg_29 => neg_43
#   setitem_28 => index_put_28, select_254, select_255, view_199
#   sin => index_1
#   view_43 => view_197
# Graph fragment:
#   %select_scatter_default_55 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_55]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_197 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_99, [1, 1, 2, 128]), kwargs = {})
#   %permute_172 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_197, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_415 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_172, torch.float32), kwargs = {})
#   %split_29 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_415, 64, -1), kwargs = {})
#   %select_254 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_55, 0, 0), kwargs = {})
#   %select_255 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_254, 0, 14), kwargs = {})
#   %mul_158 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_415, %index), kwargs = {})
#   %neg_43 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_59,), kwargs = {})
#   %cat_29 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_43, %getitem_58], -1), kwargs = {})
#   %mul_159 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_29, %index_1), kwargs = {})
#   %add_114 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_158, %mul_159), kwargs = {})
#   %convert_element_type_417 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_114, torch.bfloat16), kwargs = {})
#   %view_199 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_417, [2, 1, 128]), kwargs = {})
#   %index_put_28 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_255, [None, None, %arg1_1], %view_199), kwargs = {})
#   return %index_put_28
# Pointwise Triton kernel over a flat index space of 2097152 elements
# (1D grid, program_id(0) tiles x).
#
# Despite the long fused-op name, the kernel body only performs a copy: it
# loads the bf16 element at in_ptr0 + 29360128 + index (going through an fp32
# temporary) and stores it to out_ptr0 + index, whose pointer type is bf16.
#
# 29360128 == 14 * 2097152, which matches the index-14 select along the
# stride-2097152 dimension in the graph fragment comment above.
# NOTE(review): the index_put / rotary-embedding arithmetic listed in the
# graph fragment is not present in this body; presumably it is emitted by a
# separate kernel or call that writes into the buffer produced here - confirm
# against the call site.
#
# xmask is built as all-True and the load/store pass None as their mask, so
# no bounds masking is applied inside the kernel.
# The kernel source below is a string literal passed to async_compile.triton;
# its text must stay unchanged because it is what gets compiled.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_82 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_82', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_82', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_82(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (29360128 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/lz/clz5lgyuyqpn7xvw7dxdmpqrovu5lpgkk4vub6wac7fnhs53brih.py
# Topologically Sorted Source Nodes: [setitem_29, view_44, value_states_29], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_29 => index_put_29, select_259, select_260, view_200
#   value_states_29 => permute_173
#   view_44 => view_198
# Graph fragment:
#   %buf390 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf390]
#   %select_scatter_default_55 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_55]
#   %select_int_28 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_55, 0, 0), kwargs = {})
#   %select_scatter_default_56 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_28, %index_put_28, 0, 14), kwargs = {})
#   %select_scatter_default_57 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_55, %select_scatter_default_56, 0, 0), kwargs = {})
#   %select_259 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_57, 0, 1), kwargs = {})
#   %select_260 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_259, 0, 14), kwargs = {})
#   %view_198 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_100, [1, 1, 2, 128]), kwargs = {})
#   %permute_173 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_198, [0, 2, 1, 3]), kwargs = {})
#   %view_200 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_173, [2, 1, 128]), kwargs = {})
#   %index_put_29 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_260, [None, None, %arg1_1], %view_200), kwargs = {})
#   return %index_put_29
# Pointwise Triton kernel over a flat index space of 2097152 elements
# (1D grid, program_id(0) tiles x).
#
# The body issues three bf16 loads per index:
#   tmp5 = in_ptr0[index]
#   tmp6 = in_ptr1[29360128 + index]      (29360128 == 14 * 2097152)
#   tmp8 = in_ptr1[88080384 + index]      (88080384 == 58720256 + 14 * 2097152)
# and then selects between them with two tl.where calls whose conditions are
# built purely from constants:
#   tmp2 = (1 == 0)   -> always False
#   tmp4 = (14 == 14) -> always True
# so tmp7 is always tmp5 and the stored value tmp9 is always tmp8.  In effect
# the kernel copies in_ptr1[88080384 + index] to out_ptr0[index]; the other
# two loads do not influence the result.  This is the residue of the
# select_scatter / select chain in the graph fragment comment above being
# evaluated at fixed indices (outer index 1, inner index 14).
#
# xmask is built as all-True and the loads/store pass None as their mask, so
# no bounds masking is applied inside the kernel.
# The kernel source below is a string literal passed to async_compile.triton;
# its text must stay unchanged because it is what gets compiled.
triton_poi_fused_index_put_select_transpose_view_83 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_83', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_83', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_83(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (29360128 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (88080384 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 14, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/fh/cfhupspluqxbojkpyqddedy7eg4to2jqyr2khywhsqawg5fcc3zn.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf393 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf393]
#   %buf390 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf390]
#   %select_scatter_default_55 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_55]
#   %select_int_28 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_55, 0, 0), kwargs = {})
#   %select_scatter_default_56 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_28, %index_put_28, 0, 14), kwargs = {})
#   %select_scatter_default_57 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_55, %select_scatter_default_56, 0, 0), kwargs = {})
#   %select_int_29 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_57, 0, 1), kwargs = {})
#   %select_scatter_default_58 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_29, %index_put_29, 0, 14), kwargs = {})
#   %select_scatter_default_59 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_57, %select_scatter_default_58, 0, 1), kwargs = {})
#   return %select_scatter_default_59
# Pointwise Triton kernel over a flat index space of 117440512 elements
# (== 2 * 58720256; 1D grid, program_id(0) tiles x).
#
# Index decomposition used by the body:
#   x2 = index // 58720256          outer slot (compared against 0 and 1)
#   x1 = (index // 2097152) % 28    slot within the outer slot (compared to 14)
#   x0 = index % 2097152            offset inside one 2097152-element slab
#   x3 = index % 58720256           offset inside one outer slot
#   x4 = index                      flat offset
#
# Working through the tl.where chain (tmp8 is the constant comparison 1 == 0,
# so tmp13 always equals tmp12), the value stored to out_ptr0[index] is:
#   x2 == 1 and x1 == 14  -> in_ptr0[x0]
#   x2 == 1 and x1 != 14  -> in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 14  -> in_ptr1[x0]
#   x2 == 0 and x1 != 14  -> in_ptr2[x3]
#   otherwise             -> in_ptr2[x4]
# i.e. the output is in_ptr2 with slab 14 of outer slot 0 replaced by in_ptr1
# and slab 14 of outer slot 1 replaced by in_ptr0, which is the pair of nested
# select_scatter ops in the graph fragment comment above.
# NOTE(review): from the placeholder order in that comment, in_ptr0/in_ptr1/
# in_ptr2 are presumably buf393/buf390/select_scatter_default_55 - confirm
# against the call site.
#
# All five loads are issued for every index regardless of which branch is
# selected.  xmask is built as all-True and the loads/store pass None as their
# mask, so no bounds masking is applied inside the kernel.
# The kernel source below is a string literal passed to async_compile.triton;
# its text must stay unchanged because it is what gets compiled.
triton_poi_fused_84 = async_compile.triton('triton_poi_fused_84', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_84', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_84(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 14, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/5z/c5z7iopqquzhqrfj7rfwpzidz5wj5lqnbsqprzjpx6mq6w37cpgu.py
# Topologically Sorted Source Nodes: [attn_output_56], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_56 => clone_28, convert_element_type_419, expand_84, mul_161, permute_174, select_264, select_265, unsqueeze_28, view_201
# Graph fragment:
#   %select_scatter_default_59 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_59]
#   %select_264 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_59, 0, 0), kwargs = {})
#   %select_265 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_264, 0, 14), kwargs = {})
#   %convert_element_type_419 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_265, torch.float32), kwargs = {})
#   %unsqueeze_28 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_419, 2), kwargs = {})
#   %expand_84 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_28, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_28 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_84,), kwargs = {memory_format: torch.contiguous_format})
#   %view_201 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_28, [1, 16, 8192, 128]), kwargs = {})
#   %permute_174 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_201, [0, 1, 3, 2]), kwargs = {})
#   %mul_161 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_174, 0.29730177875068026), kwargs = {})
#   return %expand_87
# Pointwise Triton kernel over a 2048 (y) x 8192 (x) index space, launched on a
# 2D grid (program_id(1) tiles y, program_id(0) tiles x).
#
# For every (y, x) pair the kernel:
#   * splits y into y0 = y % 128 and y1 = y // 128,
#   * loads one bf16 element from
#       in_ptr0 + 29360128 + y0 + 128*x + 1048576*(y1 // 8)
#     and upcasts it to fp32,
#   * multiplies it by the constant 0.29730177875068026,
#   * stores the fp32 result to out_ptr0 + x + 8192*y.
#
# 29360128 == 14 * 2097152, which matches the index-14 select along the
# stride-2097152 dimension shown in the graph fragment comment above.  The
# `y1 // 8` term makes eight consecutive y1 values read the same
# 1048576-element slab, i.e. the expand-by-8 from the graph fragment is done
# by re-reading rather than by materialising copies.  The read is strided by
# 128 along x while the write is contiguous along x, so the transpose is
# folded into the addressing.
#
# ymask/xmask are built as all-True and every load/store passes None as its
# mask, so no bounds masking is applied inside the kernel.
# NOTE(review): the scale is numerically 128 ** -0.25; presumably one half of
# an attention scaling factor applied to this operand - confirm against the
# model code.
# The kernel source below is a string literal passed to async_compile.triton;
# its text must stay unchanged because it is what gets compiled.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_85 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_85', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_85', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_85(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (29360128 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/xz/cxzjokis23sbqzf3zrqelvbmwny6wsp5fvnhpeo4gm5t2o6z4vnr.py
# Topologically Sorted Source Nodes: [setitem_29, attn_output_56], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_56 => clone_29, convert_element_type_420, expand_85, unsqueeze_29
#   setitem_29 => select_262, select_263
# Graph fragment:
#   %select_scatter_default_59 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_59]
#   %select_262 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_59, 0, 1), kwargs = {})
#   %select_263 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_262, 0, 14), kwargs = {})
#   %convert_element_type_420 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_263, torch.float32), kwargs = {})
#   %unsqueeze_29 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_420, 2), kwargs = {})
#   %expand_85 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_29, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_29 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_85,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_29
# Summary of the kernel registered below (derived from its body and the graph
# fragment above):
#   * 1D pointwise kernel over 16777216 output elements; the `xnumel` argument
#     is overwritten with that constant inside the kernel.
#   * Each output element `xindex` reads the bf16 value at
#         in_ptr0[88080384 + (xindex % 1048576) + 1048576 * (xindex // 8388608)]
#     and writes it as fp32 to out_ptr0[xindex].
#   * 88080384 == 1 * 58720256 + 14 * 2097152, i.e. index 1 along the stride
#     58720256 dim followed by index 14 along the stride 2097152 dim, matching
#     select_262 / select_263 in the graph fragment.
#   * Because the position inside every 8388608-element output group is taken
#     modulo 1048576, each 1048576-element input slab is written 8 times
#     back-to-back (the expand -> contiguous clone in the graph fragment).
#   * Loads and stores pass `None` as the mask; `xmask` is computed but unused.
#     NOTE(review): this presumably relies on the launch grid covering exactly
#     xnumel elements -- the launch site is not visible here.
#   * `tmp0` is already float32 after the load-side cast, so the following
#     `tmp0.to(tl.float32)` does not change the dtype.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_86 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_86', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_86', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_86(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (88080384 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/vg/cvg7zc4ufeoocjwiq2fbv3kqzkjnybxlfubbpezyzvh4aqwyo3nb.py
# Topologically Sorted Source Nodes: [cos, sin, view_46, key_states_46, k_15, chunk_31, setitem_30, mul_139, neg_31, cat_31, mul_140, k_embed_15, key_states_47], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_31 => cat_31
#   chunk_31 => split_31
#   cos => index
#   k_15 => convert_element_type_444
#   k_embed_15 => add_122
#   key_states_46 => permute_184
#   key_states_47 => convert_element_type_446
#   mul_139 => mul_169
#   mul_140 => mul_170
#   neg_31 => neg_46
#   setitem_30 => index_put_30, select_272, select_273, view_213
#   sin => index_1
#   view_46 => view_211
# Graph fragment:
#   %select_scatter_default_59 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_59]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_211 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_106, [1, 1, 2, 128]), kwargs = {})
#   %permute_184 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_211, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_444 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_184, torch.float32), kwargs = {})
#   %split_31 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_444, 64, -1), kwargs = {})
#   %select_272 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_59, 0, 0), kwargs = {})
#   %select_273 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_272, 0, 15), kwargs = {})
#   %mul_169 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %index), kwargs = {})
#   %neg_46 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_63,), kwargs = {})
#   %cat_31 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_46, %getitem_62], -1), kwargs = {})
#   %mul_170 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %index_1), kwargs = {})
#   %add_122 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_169, %mul_170), kwargs = {})
#   %convert_element_type_446 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_122, torch.bfloat16), kwargs = {})
#   %view_213 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_446, [2, 1, 128]), kwargs = {})
#   %index_put_30 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_273, [None, None, %arg1_1], %view_213), kwargs = {})
#   return %index_put_30
# Summary of the kernel registered below (derived from its body and the graph
# fragment above):
#   * 1D pointwise kernel over 2097152 elements; the `xnumel` argument is
#     overwritten with that constant inside the kernel.
#   * The body is a straight copy: out_ptr0[x0] = in_ptr0[31457280 + x0].
#     31457280 == 15 * 2097152, i.e. index 0 along the stride 58720256 dim and
#     index 15 along the stride 2097152 dim (select_272 / select_273 in the
#     graph fragment).
#   * The value is upcast to float32 on load and stored through a bf16 pointer
#     (see the `signature` entry in triton_meta).
#   * None of the mul / neg / cat / add arithmetic listed in the kernel name
#     appears in this body, and no indexed write is performed here.
#     NOTE(review): presumably the index_put update is applied to the copied
#     buffer by a separate kernel -- not visible in this chunk, confirm.
#   * Loads and stores pass `None` as the mask; `xmask` is computed but unused.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_87 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_87', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_87', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_87(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (31457280 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/7o/c7oe2by7ongqhyndszswqsiofb52b3fql7uphjgvmzy73mu52vz4.py
# Topologically Sorted Source Nodes: [setitem_31, view_47, value_states_31], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_31 => index_put_31, select_277, select_278, view_214
#   value_states_31 => permute_185
#   view_47 => view_212
# Graph fragment:
#   %buf417 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf417]
#   %select_scatter_default_59 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_59]
#   %select_int_30 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_59, 0, 0), kwargs = {})
#   %select_scatter_default_60 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_30, %index_put_30, 0, 15), kwargs = {})
#   %select_scatter_default_61 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_59, %select_scatter_default_60, 0, 0), kwargs = {})
#   %select_277 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_61, 0, 1), kwargs = {})
#   %select_278 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_277, 0, 15), kwargs = {})
#   %view_212 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_107, [1, 1, 2, 128]), kwargs = {})
#   %permute_185 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_212, [0, 2, 1, 3]), kwargs = {})
#   %view_214 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_185, [2, 1, 128]), kwargs = {})
#   %index_put_31 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_278, [None, None, %arg1_1], %view_214), kwargs = {})
#   return %index_put_31
# Summary of the kernel registered below (derived from its body and the graph
# fragment above):
#   * 1D pointwise kernel over 2097152 elements; the `xnumel` argument is
#     overwritten with that constant inside the kernel.
#   * Three loads are issued per element:
#         tmp5 = in_ptr0[x0]
#         tmp6 = in_ptr1[31457280 + x0]   (31457280 == 15 * 2097152)
#         tmp8 = in_ptr1[90177536 + x0]   (90177536 == 58720256 + 15 * 2097152)
#   * Both `tl.where` conditions compare constants: tmp2 is `1 == 0` (always
#     False) and tmp4 is `15 == 15` (always True). Hence tmp7 == tmp5 and
#     tmp9 == tmp8, so the stored value is always in_ptr1[90177536 + x0];
#     tmp5 and tmp6 only feed the branch that is never selected.
#   * The comparisons mirror the two nested select_scatter nodes in the graph
#     fragment (scatter at index 15, then at index 0) read back through
#     select(…, 0, 1) / select(…, 0, 15).
#   * NOTE(review): presumably in_ptr0 is buf417 and in_ptr1 is
#     select_scatter_default_59 (the placeholder order above) -- the call site
#     is not visible in this chunk.
#   * Loads and stores pass `None` as the mask; `xmask` is computed but unused.
triton_poi_fused_index_put_select_transpose_view_88 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_88', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_88', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_88(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (31457280 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (90177536 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 15, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/g4/cg4tn6lzkqxxsqekjpbsyjonyqhkgfjtm6ciiggtiwsdtayeqap5.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf420 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf420]
#   %buf417 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf417]
#   %select_scatter_default_59 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_59]
#   %select_int_30 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_59, 0, 0), kwargs = {})
#   %select_scatter_default_60 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_30, %index_put_30, 0, 15), kwargs = {})
#   %select_scatter_default_61 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_59, %select_scatter_default_60, 0, 0), kwargs = {})
#   %select_int_31 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_61, 0, 1), kwargs = {})
#   %select_scatter_default_62 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_31, %index_put_31, 0, 15), kwargs = {})
#   %select_scatter_default_63 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_61, %select_scatter_default_62, 0, 1), kwargs = {})
#   return %select_scatter_default_63
# Summary of the kernel registered below (derived from its body and the graph
# fragment above):
#   * 1D pointwise kernel over 117440512 (== 2 * 58720256) output elements; the
#     `xnumel` argument is overwritten with that constant inside the kernel.
#   * Index decomposition of the flat output index `xindex`:
#         x2 = xindex // 58720256          which half of the buffer (0 or 1)
#         x1 = (xindex // 2097152) % 28    which 2097152-element block in it
#         x0 = xindex % 2097152            offset inside that block
#         x3 = xindex % 58720256           offset inside the half
#   * tmp8 is the constant comparison `1 == 0` (always False), so tmp13 is
#     always tmp12. With that folded, the stored value is:
#         x2 == 1 and x1 == 15  ->  in_ptr0[x0]
#         x2 == 1 and x1 != 15  ->  in_ptr2[58720256 + x3]
#         x2 == 0 and x1 == 15  ->  in_ptr1[x0]
#         x2 == 0 and x1 != 15  ->  in_ptr2[x3]
#     (the remaining fallback, in_ptr2[x4], is only selected when x2 is
#     neither 0 nor 1, which cannot occur for xindex < 117440512).
#   * Net effect: out_ptr0 is a full copy of in_ptr2 with block 15 of half 0
#     replaced by in_ptr1 and block 15 of half 1 replaced by in_ptr0, matching
#     the select_scatter chain ending in select_scatter_default_63.
#   * NOTE(review): presumably in_ptr0 / in_ptr1 / in_ptr2 are buf420 / buf417 /
#     select_scatter_default_59 (the placeholder order above) -- the call site
#     is not visible in this chunk.
#   * All five loads are issued for every element regardless of which branch
#     is selected; loads and stores pass `None` as the mask and `xmask` is
#     computed but unused.
triton_poi_fused_89 = async_compile.triton('triton_poi_fused_89', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_89', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_89(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 15, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/2e/c2et3nzbcfxw7a7mk3ra4skzgg6d5su3q7j4535zlcdpgqsrl3zm.py
# Topologically Sorted Source Nodes: [attn_output_60], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_60 => clone_30, convert_element_type_448, expand_90, mul_172, permute_186, select_282, select_283, unsqueeze_30, view_215
# Graph fragment:
#   %select_scatter_default_63 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_63]
#   %select_282 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_63, 0, 0), kwargs = {})
#   %select_283 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_282, 0, 15), kwargs = {})
#   %convert_element_type_448 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_283, torch.float32), kwargs = {})
#   %unsqueeze_30 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_448, 2), kwargs = {})
#   %expand_90 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_30, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_30 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_90,), kwargs = {memory_format: torch.contiguous_format})
#   %view_215 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_30, [1, 16, 8192, 128]), kwargs = {})
#   %permute_186 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_215, [0, 1, 3, 2]), kwargs = {})
#   %mul_172 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_186, 0.29730177875068026), kwargs = {})
#   return %expand_93
# Summary of the kernel registered below (derived from its body and the graph
# fragment above):
#   * 2D tiled pointwise kernel (Grid2D, TileHint.SQUARE): y spans 2048
#     (== 16 * 128) indices on program_id(1), x spans 8192 indices on
#     program_id(0). `ynumel` / `xnumel` are overwritten with those constants.
#   * y is split into y0 = yindex % 128 and y1 = yindex // 128 (0..15).
#   * Each output element reads the bf16 value at
#         in_ptr0[31457280 + y0 + 128 * x2 + 1048576 * (y1 // 8)]
#     (31457280 == 15 * 2097152, matching select_282 / select_283), converts
#     it to fp32, multiplies by 0.29730177875068026 and stores it at
#         out_ptr0[x2 + 8192 * y3].
#   * The input is unit-stride in y0 while the output is unit-stride in x2, so
#     the last two dims are swapped on the way out (permute_186 in the graph).
#   * `y1 // 8` maps the 16 values of y1 onto 2 input slabs of 1048576 elements,
#     so each slab is reused for 8 consecutive y1 values (expand_90 / clone_30).
#   * NOTE(review): 0.29730177875068026 is numerically 128 ** -0.25; presumably
#     a split attention-scale factor -- confirm against the model code.
#   * Loads and stores pass `None` as the mask; `xmask` / `ymask` are computed
#     but unused.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_90 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_90', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_90', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_90(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (31457280 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/mx/cmxvgb2tlbd7gluwk524mvycoulkhd4qybyw4qn3iyylzkrgwphy.py
# Topologically Sorted Source Nodes: [setitem_31, attn_output_60], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_60 => clone_31, convert_element_type_449, expand_91, unsqueeze_31
#   setitem_31 => select_280, select_281
# Graph fragment:
#   %select_scatter_default_63 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_63]
#   %select_280 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_63, 0, 1), kwargs = {})
#   %select_281 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_280, 0, 15), kwargs = {})
#   %convert_element_type_449 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_281, torch.float32), kwargs = {})
#   %unsqueeze_31 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_449, 2), kwargs = {})
#   %expand_91 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_31, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_31 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_91,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_31
# Summary of the kernel registered below (derived from its body and the graph
# fragment above):
#   * 1D pointwise kernel over 16777216 output elements; the `xnumel` argument
#     is overwritten with that constant inside the kernel.
#   * Each output element `xindex` reads the bf16 value at
#         in_ptr0[90177536 + (xindex % 1048576) + 1048576 * (xindex // 8388608)]
#     and writes it as fp32 to out_ptr0[xindex].
#   * 90177536 == 1 * 58720256 + 15 * 2097152, i.e. index 1 along the stride
#     58720256 dim followed by index 15 along the stride 2097152 dim, matching
#     select_280 / select_281 in the graph fragment.
#   * Because the position inside every 8388608-element output group is taken
#     modulo 1048576, each 1048576-element input slab is written 8 times
#     back-to-back (the expand -> contiguous clone in the graph fragment).
#   * Loads and stores pass `None` as the mask; `xmask` is computed but unused.
#     NOTE(review): this presumably relies on the launch grid covering exactly
#     xnumel elements -- the launch site is not visible here.
#   * `tmp0` is already float32 after the load-side cast, so the following
#     `tmp0.to(tl.float32)` does not change the dtype.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_91 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_91', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_91', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_91(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (90177536 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/uk/cuk5aiblosjva5f3culo6rwtjszthiohsvzah3mdknsa4ne3pr6w.py
# Topologically Sorted Source Nodes: [cos, sin, view_49, key_states_49, k_16, chunk_33, setitem_32, mul_148, neg_33, cat_33, mul_149, k_embed_16, key_states_50], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_33 => cat_33
#   chunk_33 => split_33
#   cos => index
#   k_16 => convert_element_type_473
#   k_embed_16 => add_130
#   key_states_49 => permute_196
#   key_states_50 => convert_element_type_475
#   mul_148 => mul_180
#   mul_149 => mul_181
#   neg_33 => neg_49
#   setitem_32 => index_put_32, select_290, select_291, view_227
#   sin => index_1
#   view_49 => view_225
# Graph fragment:
#   %select_scatter_default_63 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_63]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_225 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_113, [1, 1, 2, 128]), kwargs = {})
#   %permute_196 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_225, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_473 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_196, torch.float32), kwargs = {})
#   %split_33 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_473, 64, -1), kwargs = {})
#   %select_290 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_63, 0, 0), kwargs = {})
#   %select_291 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_290, 0, 16), kwargs = {})
#   %mul_180 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_473, %index), kwargs = {})
#   %neg_49 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_67,), kwargs = {})
#   %cat_33 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_49, %getitem_66], -1), kwargs = {})
#   %mul_181 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_33, %index_1), kwargs = {})
#   %add_130 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_180, %mul_181), kwargs = {})
#   %convert_element_type_475 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_130, torch.bfloat16), kwargs = {})
#   %view_227 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_475, [2, 1, 128]), kwargs = {})
#   %index_put_32 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_291, [None, None, %arg1_1], %view_227), kwargs = {})
#   return %index_put_32
# Registers (via async_compile.triton) a 1-D pointwise Triton kernel that copies
# a fixed window of 2097152 elements from in_ptr0 into out_ptr0.
#   * Source window starts at element offset 33554432 (= 16 * 2097152) of
#     in_ptr0; destination is written densely starting at offset 0.
#   * The `xnumel` argument is overwritten with the constant 2097152 inside the
#     kernel, so the value passed by the caller is not used.
#   * `xmask` is built but never consumed: both the load and the store pass
#     `None` as the mask, i.e. the accesses are unmasked.
#   * The loaded value is upcast to float32; out_ptr0 is declared '*bf16' in the
#     signature, so the store presumably narrows it back -- confirm against
#     Triton's tl.store semantics.
# NOTE(review): despite the long fused name, the kernel body shown here performs
# only the copy; the rotary/index_put work named in the identifier is presumably
# applied by a different kernel -- verify against the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_92 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_92', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_92', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_92(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (33554432 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/jy/cjyvzi3gnfl3zbrot4mrqr2tmhbuiigvdjydufgevs7yegp77oox.py
# Topologically Sorted Source Nodes: [setitem_33, view_50, value_states_33], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_33 => index_put_33, select_295, select_296, view_228
#   value_states_33 => permute_197
#   view_50 => view_226
# Graph fragment:
#   %buf445 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf445]
#   %select_scatter_default_63 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_63]
#   %select_int_32 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_63, 0, 0), kwargs = {})
#   %select_scatter_default_64 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_32, %index_put_32, 0, 16), kwargs = {})
#   %select_scatter_default_65 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_63, %select_scatter_default_64, 0, 0), kwargs = {})
#   %select_295 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_65, 0, 1), kwargs = {})
#   %select_296 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_295, 0, 16), kwargs = {})
#   %view_226 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_114, [1, 1, 2, 128]), kwargs = {})
#   %permute_197 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_226, [0, 2, 1, 3]), kwargs = {})
#   %view_228 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_197, [2, 1, 128]), kwargs = {})
#   %index_put_33 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_296, [None, None, %arg1_1], %view_228), kwargs = {})
#   return %index_put_33
# Registers (via async_compile.triton) a 1-D pointwise Triton kernel over
# 2097152 elements that issues three loads and one store:
#   tmp5 = in_ptr0[x0]
#   tmp6 = in_ptr1[33554432 + x0]   (33554432 = 16 * 2097152)
#   tmp8 = in_ptr1[92274688 + x0]   (92274688 = 58720256 + 16 * 2097152)
# Both selection predicates compare compile-time constants:
#   tmp2 = (1 == 0)   -> always False
#   tmp4 = (16 == 16) -> always True
# Consequently tmp7 always equals tmp5, and tmp9 = where(tmp2, tmp7, tmp8)
# always equals tmp8, so the value stored to out_ptr0[x0] is the element read
# from in_ptr1 at offset 92274688 + x0. The tmp5/tmp6 loads do not influence
# the stored result in the code as written.
# `xnumel` is overwritten with a constant and `xmask` is unused; every memory
# access passes `None` as its mask.
triton_poi_fused_index_put_select_transpose_view_93 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_93', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_93', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_93(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (33554432 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (92274688 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 16, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/by/cby5xuphktkidvplxbhpbusvhmtyxlmzop2zsv7e3zt7ynp5h6j2.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf448 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf448]
#   %buf445 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf445]
#   %select_scatter_default_63 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_63]
#   %select_int_32 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_63, 0, 0), kwargs = {})
#   %select_scatter_default_64 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_32, %index_put_32, 0, 16), kwargs = {})
#   %select_scatter_default_65 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_63, %select_scatter_default_64, 0, 0), kwargs = {})
#   %select_int_33 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_65, 0, 1), kwargs = {})
#   %select_scatter_default_66 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_33, %index_put_33, 0, 16), kwargs = {})
#   %select_scatter_default_67 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_65, %select_scatter_default_66, 0, 1), kwargs = {})
#   return %select_scatter_default_67
# Registers (via async_compile.triton) a 1-D pointwise Triton kernel that
# produces 117440512 (= 2 * 58720256) output elements by choosing, per element,
# between two small buffers (in_ptr0, in_ptr1) and a large one (in_ptr2).
#
# Index decomposition of the flat index `xindex`:
#   x2 = xindex // 58720256          outermost coordinate
#   x1 = (xindex // 2097152) % 28    second coordinate
#   x0 = xindex % 2097152            offset inside one 2097152-element slice
#   x3 = xindex % 58720256           offset inside one outermost slice
#   x4 = xindex                      flat offset
#
# Selection logic after folding the constant predicate tmp8 = (1 == 0) = False
# (which makes tmp13 always equal tmp12):
#   x2 == 1:  out = in_ptr0[x0] if x1 == 16 else in_ptr2[58720256 + x3]
#   x2 == 0:  out = in_ptr1[x0] if x1 == 16 else in_ptr2[x3]
#   otherwise out = in_ptr2[x4]
# For xindex < 117440512, x2 is only ever 0 or 1, so the last case is not
# reached within that range.
#
# `xnumel` is overwritten with a constant and `xmask` is unused; all loads and
# the store are unmasked (mask argument is `None`).
# NOTE(review): size_hints (134217728) exceeds xnumel; because nothing is
# masked, this presumably relies on the launcher choosing a grid/XBLOCK that
# tiles exactly 117440512 elements -- confirm in the Inductor runtime.
triton_poi_fused_94 = async_compile.triton('triton_poi_fused_94', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_94', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_94(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 16, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/sd/csdra36nj72cso7kqcetdxnafctttywvfqrmm7atatmvbf2rf3q3.py
# Topologically Sorted Source Nodes: [attn_output_64], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_64 => clone_32, convert_element_type_477, expand_96, mul_183, permute_198, select_300, select_301, unsqueeze_32, view_229
# Graph fragment:
#   %select_scatter_default_67 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_67]
#   %select_300 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_67, 0, 0), kwargs = {})
#   %select_301 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_300, 0, 16), kwargs = {})
#   %convert_element_type_477 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_301, torch.float32), kwargs = {})
#   %unsqueeze_32 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_477, 2), kwargs = {})
#   %expand_96 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_32, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_32 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_96,), kwargs = {memory_format: torch.contiguous_format})
#   %view_229 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_32, [1, 16, 8192, 128]), kwargs = {})
#   %permute_198 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_229, [0, 1, 3, 2]), kwargs = {})
#   %mul_183 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_198, 0.29730177875068026), kwargs = {})
#   return %expand_99
# Registers (via async_compile.triton) a 2-D tiled pointwise Triton kernel
# (ynumel = 2048, xnumel = 8192) that gathers from in_ptr0, scales, and writes
# float32 results to out_ptr0.
#
# Indexing:
#   y0 = yindex % 128, y1 = yindex // 128 (0..15 for yindex < 2048), x2 = xindex
#   source offset = 33554432 + y0 + 128 * x2 + 1048576 * (y1 // 8)
#       - 33554432 = 16 * 2097152 is a fixed base offset into in_ptr0.
#       - (y1 // 8) means eight consecutive y1 values read the same
#         1048576-element source slice, i.e. each slice is replicated 8 times.
#   destination offset = x2 + 8192 * yindex
#       - the x coordinate is the fastest-varying output dimension, whereas in
#         the source it is strided by 128: the kernel transposes while copying.
#
# Each value is multiplied by 0.29730177875068026, which numerically equals
# 128 ** -0.25.
# NOTE(review): presumably an attention scaling factor applied to one operand
# of the score matmul -- confirm against the model code.
#
# The second `.to(tl.float32)` (tmp1) is a no-op on an already-float32 value.
# `ynumel`/`xnumel` are overwritten with constants and `ymask`/`xmask` are
# unused; the load and the store are unmasked.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_95 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_95', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_95', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_95(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (33554432 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/cg/ccgce4wx2evb7k42lhfdvnucg6xdme5p7qwtkeu555kfhck5r5hz.py
# Topologically Sorted Source Nodes: [setitem_33, attn_output_64], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_64 => clone_33, convert_element_type_478, expand_97, unsqueeze_33
#   setitem_33 => select_298, select_299
# Graph fragment:
#   %select_scatter_default_67 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_67]
#   %select_298 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_67, 0, 1), kwargs = {})
#   %select_299 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_298, 0, 16), kwargs = {})
#   %convert_element_type_478 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_299, torch.float32), kwargs = {})
#   %unsqueeze_33 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_478, 2), kwargs = {})
#   %expand_97 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_33, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_33 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_97,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_33
# Registers (via async_compile.triton) a 1-D pointwise Triton kernel that
# writes 16777216 float32 elements to out_ptr0, each read from in_ptr0 and
# upcast to float32.
#
# Indexing:
#   x0 = xindex % 1048576        offset inside one 1048576-element slice
#   x2 = xindex // 8388608       which source slice (0 or 1 for
#                                xindex < 16777216)
#   source offset      = 92274688 + x0 + 1048576 * x2
#                        (92274688 = 58720256 + 16 * 2097152)
#   destination offset = xindex (dense)
# Because 8388608 = 8 * 1048576, every source slice is written out eight times
# back-to-back, i.e. the kernel materializes an 8x replication of each slice.
#
# The second `.to(tl.float32)` (tmp1) is a no-op on an already-float32 value.
# `xnumel` is overwritten with a constant and `xmask` is unused; the load and
# the store are unmasked.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_96 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_96', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_96', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_96(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (92274688 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/k3/ck3qnevlwt7oyqvol6j2bc4upn5tne3rknhzuo5pmadllp4bgrhw.py
# Topologically Sorted Source Nodes: [cos, sin, view_52, key_states_52, k_17, chunk_35, setitem_34, mul_157, neg_35, cat_35, mul_158, k_embed_17, key_states_53], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_35 => cat_35
#   chunk_35 => split_35
#   cos => index
#   k_17 => convert_element_type_502
#   k_embed_17 => add_138
#   key_states_52 => permute_208
#   key_states_53 => convert_element_type_504
#   mul_157 => mul_191
#   mul_158 => mul_192
#   neg_35 => neg_52
#   setitem_34 => index_put_34, select_308, select_309, view_241
#   sin => index_1
#   view_52 => view_239
# Graph fragment:
#   %select_scatter_default_67 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_67]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_239 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_120, [1, 1, 2, 128]), kwargs = {})
#   %permute_208 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_239, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_502 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_208, torch.float32), kwargs = {})
#   %split_35 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_502, 64, -1), kwargs = {})
#   %select_308 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_67, 0, 0), kwargs = {})
#   %select_309 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_308, 0, 17), kwargs = {})
#   %mul_191 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_502, %index), kwargs = {})
#   %neg_52 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_71,), kwargs = {})
#   %cat_35 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_52, %getitem_70], -1), kwargs = {})
#   %mul_192 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_35, %index_1), kwargs = {})
#   %add_138 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_191, %mul_192), kwargs = {})
#   %convert_element_type_504 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_138, torch.bfloat16), kwargs = {})
#   %view_241 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_504, [2, 1, 128]), kwargs = {})
#   %index_put_34 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_309, [None, None, %arg1_1], %view_241), kwargs = {})
#   return %index_put_34
# Registers (via async_compile.triton) a 1-D pointwise Triton kernel that copies
# a fixed window of 2097152 elements from in_ptr0 into out_ptr0.
#   * Source window starts at element offset 35651584 (= 17 * 2097152) of
#     in_ptr0; destination is written densely starting at offset 0.
#   * The `xnumel` argument is overwritten with the constant 2097152 inside the
#     kernel, so the value passed by the caller is not used.
#   * `xmask` is built but never consumed: both the load and the store pass
#     `None` as the mask, i.e. the accesses are unmasked.
#   * The loaded value is upcast to float32; out_ptr0 is declared '*bf16' in the
#     signature, so the store presumably narrows it back -- confirm against
#     Triton's tl.store semantics.
# NOTE(review): despite the long fused name, the kernel body shown here performs
# only the copy; the rotary/index_put work named in the identifier is presumably
# applied by a different kernel -- verify against the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_97 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_97', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_97', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_97(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (35651584 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/hs/chsvi45ug32uwlubuh7zlem3dikbgbixnz5n5qqj3wdb2bfvv2ca.py
# Topologically Sorted Source Nodes: [setitem_35, view_53, value_states_35], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_35 => index_put_35, select_313, select_314, view_242
#   value_states_35 => permute_209
#   view_53 => view_240
# Graph fragment:
#   %buf472 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf472]
#   %select_scatter_default_67 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_67]
#   %select_int_34 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_67, 0, 0), kwargs = {})
#   %select_scatter_default_68 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_34, %index_put_34, 0, 17), kwargs = {})
#   %select_scatter_default_69 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_67, %select_scatter_default_68, 0, 0), kwargs = {})
#   %select_313 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_69, 0, 1), kwargs = {})
#   %select_314 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_313, 0, 17), kwargs = {})
#   %view_240 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_121, [1, 1, 2, 128]), kwargs = {})
#   %permute_209 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_240, [0, 2, 1, 3]), kwargs = {})
#   %view_242 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_209, [2, 1, 128]), kwargs = {})
#   %index_put_35 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_314, [None, None, %arg1_1], %view_242), kwargs = {})
#   return %index_put_35
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused_index_put_select_transpose_view_98' (device_str='cuda').
#
# What the kernel body does, read directly from its statements:
#   * 1-D pointwise pass; xnumel is overwritten with the constant 2097152 and
#     every load/store passes None as its mask (xmask is computed but unused).
#   * tmp2 compares the constants 1 and 0, so it is always False;
#     tmp4 compares 17 with itself, so it is always True.
#   * Hence tmp7 == tmp5 and tmp9 == tmp8: the value stored to out_ptr0[x0] is
#     always in_ptr1[94371840 + x0]. The loads of in_ptr0[x0] (tmp5) and
#     in_ptr1[35651584 + x0] (tmp6) never reach the output.
#   * 35651584 == 17 * 2097152 and 94371840 == 58720256 + 17 * 2097152.
#     Presumably these address block 17 of halves 0 and 1 of the
#     [2, 28, ...] buffer described in the graph fragment -- TODO confirm.
# NOTE(review): the kernel text is a runtime string handed to the compiler;
# it is intentionally left byte-identical.
triton_poi_fused_index_put_select_transpose_view_98 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_98', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_98', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_98(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (35651584 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (94371840 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 17, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4r/c4r6lnixzqllm5e5vkvyympqoflphzzbdpnm3l2n6fe7qhmcppxi.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf475 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf475]
#   %buf472 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf472]
#   %select_scatter_default_67 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_67]
#   %select_int_34 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_67, 0, 0), kwargs = {})
#   %select_scatter_default_68 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_34, %index_put_34, 0, 17), kwargs = {})
#   %select_scatter_default_69 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_67, %select_scatter_default_68, 0, 0), kwargs = {})
#   %select_int_35 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_69, 0, 1), kwargs = {})
#   %select_scatter_default_70 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_35, %index_put_35, 0, 17), kwargs = {})
#   %select_scatter_default_71 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_69, %select_scatter_default_70, 0, 1), kwargs = {})
#   return %select_scatter_default_71
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused_99' (device_str='cuda').
#
# What the kernel body does, read directly from its statements:
#   * 1-D pointwise pass; xnumel is overwritten with 117440512
#     (== 2 * 58720256) and all loads/stores are unmasked (mask is None).
#   * Index decomposition of the flat index:
#       x2 = xindex // 58720256          (outer index)
#       x1 = (xindex // 2097152) % 28    (index within the outer slice)
#       x0 = xindex % 2097152            (offset inside one 2097152 block)
#       x3 = xindex % 58720256           (offset inside one outer slice)
#   * tmp8 compares the constants 1 and 0 and is always False, so
#     tmp13 == tmp12 (in_ptr2[58720256 + x3]).
#   * Net selection written to out_ptr0[xindex]:
#       x2 == 1 and x1 == 17  -> in_ptr0[x0]
#       x2 == 1 and x1 != 17  -> in_ptr2[58720256 + x3]
#       x2 == 0 and x1 == 17  -> in_ptr1[x0]
#       x2 == 0 and x1 != 17  -> in_ptr2[x3]
#       otherwise             -> in_ptr2[xindex]
#     i.e. a copy of in_ptr2 with block 17 of each outer slice replaced.
#     Presumably in_ptr1/in_ptr0 are the updated key/value slices for
#     layer 17 -- TODO confirm against the caller.
# NOTE(review): the kernel text is a runtime string handed to the compiler;
# it is intentionally left byte-identical.
triton_poi_fused_99 = async_compile.triton('triton_poi_fused_99', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_99', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_99(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 17, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/up/cupdxpdqi2c5q5zucobxcbtvsciqsypvbhef3b7lmqynxv2pqps6.py
# Topologically Sorted Source Nodes: [attn_output_68], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_68 => clone_34, convert_element_type_506, expand_102, mul_194, permute_210, select_318, select_319, unsqueeze_34, view_243
# Graph fragment:
#   %select_scatter_default_71 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_71]
#   %select_318 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_71, 0, 0), kwargs = {})
#   %select_319 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_318, 0, 17), kwargs = {})
#   %convert_element_type_506 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_319, torch.float32), kwargs = {})
#   %unsqueeze_34 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_506, 2), kwargs = {})
#   %expand_102 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_34, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_34 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_102,), kwargs = {memory_format: torch.contiguous_format})
#   %view_243 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_34, [1, 16, 8192, 128]), kwargs = {})
#   %permute_210 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_243, [0, 1, 3, 2]), kwargs = {})
#   %mul_194 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_210, 0.29730177875068026), kwargs = {})
#   return %expand_105
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_100'
# (device_str='cuda').
#
# What the kernel body does, read directly from its statements:
#   * 2-D tiled pointwise pass: y comes from program_id(1), x from
#     program_id(0); ynumel/xnumel are overwritten with 2048 and 8192 and all
#     loads/stores are unmasked (mask is None).
#   * y0 = yindex % 128, y1 = yindex // 128, x2 = xindex.
#   * Reads bf16 in_ptr0[35651584 + y0 + 128*x2 + 1048576*(y1 // 8)]; because
#     of the (y1 // 8) term, 8 consecutive y1 values read the same source
#     block. 35651584 == 17 * 2097152.
#   * Multiplies by the fp32 constant 0.29730177875068026 (numerically
#     128 ** -0.25; presumably part of the attention scaling -- TODO confirm)
#     and writes fp32 to out_ptr0[x2 + 8192*yindex].
#   * The source is strided by 128 along x while the destination is contiguous
#     along x, so the last two axes are swapped relative to the input.
#   * tmp1 re-casts tmp0, which was already converted to float32 by the load.
# NOTE(review): the kernel text is a runtime string handed to the compiler;
# it is intentionally left byte-identical.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_100 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_100', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_100', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_100(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (35651584 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/kg/ckghv6ewjf6mplrn7p2za2oumb2wgwlynehzqm5toefjq7qx6wxl.py
# Topologically Sorted Source Nodes: [setitem_35, attn_output_68], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_68 => clone_35, convert_element_type_507, expand_103, unsqueeze_35
#   setitem_35 => select_316, select_317
# Graph fragment:
#   %select_scatter_default_71 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_71]
#   %select_316 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_71, 0, 1), kwargs = {})
#   %select_317 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_316, 0, 17), kwargs = {})
#   %convert_element_type_507 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_317, torch.float32), kwargs = {})
#   %unsqueeze_35 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_507, 2), kwargs = {})
#   %expand_103 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_35, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_35 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_103,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_35
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_101'
# (device_str='cuda').
#
# What the kernel body does, read directly from its statements:
#   * 1-D pointwise pass; xnumel is overwritten with 16777216 and the
#     load/store are unmasked (mask is None).
#   * x0 = xindex % 1048576 and x2 = xindex // 8388608, so within each
#     8388608-element output chunk the same 1048576-element source block is
#     read 8 times (8388608 / 1048576 == 8).
#   * Reads bf16 in_ptr0[94371840 + x0 + 1048576*x2] and writes it as fp32 to
#     out_ptr0[xindex]. 94371840 == 58720256 + 17 * 2097152.
#   * tmp1 re-casts tmp0, which was already converted to float32 by the load.
# NOTE(review): the kernel text is a runtime string handed to the compiler;
# it is intentionally left byte-identical.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_101 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_101', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_101', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_101(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (94371840 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/cu/cculcxmaibfyodegixbaj7jylsn2ypncpxysxd2yq4o3yso354qu.py
# Topologically Sorted Source Nodes: [cos, sin, view_55, key_states_55, k_18, chunk_37, setitem_36, mul_166, neg_37, cat_37, mul_167, k_embed_18, key_states_56], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_37 => cat_37
#   chunk_37 => split_37
#   cos => index
#   k_18 => convert_element_type_531
#   k_embed_18 => add_146
#   key_states_55 => permute_220
#   key_states_56 => convert_element_type_533
#   mul_166 => mul_202
#   mul_167 => mul_203
#   neg_37 => neg_55
#   setitem_36 => index_put_36, select_326, select_327, view_255
#   sin => index_1
#   view_55 => view_253
# Graph fragment:
#   %select_scatter_default_71 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_71]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_253 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_127, [1, 1, 2, 128]), kwargs = {})
#   %permute_220 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_253, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_531 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_220, torch.float32), kwargs = {})
#   %split_37 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_531, 64, -1), kwargs = {})
#   %select_326 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_71, 0, 0), kwargs = {})
#   %select_327 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_326, 0, 18), kwargs = {})
#   %mul_202 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_531, %index), kwargs = {})
#   %neg_55 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_75,), kwargs = {})
#   %cat_37 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_55, %getitem_74], -1), kwargs = {})
#   %mul_203 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_37, %index_1), kwargs = {})
#   %add_146 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_202, %mul_203), kwargs = {})
#   %convert_element_type_533 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_146, torch.bfloat16), kwargs = {})
#   %view_255 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_533, [2, 1, 128]), kwargs = {})
#   %index_put_36 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_327, [None, None, %arg1_1], %view_255), kwargs = {})
#   return %index_put_36
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_102'
# (device_str='cuda').
#
# What the kernel body does, read directly from its statements:
#   * 1-D pointwise pass; xnumel is overwritten with 2097152 and the
#     load/store are unmasked (mask is None).
#   * Copies in_ptr0[37748736 + x0] to out_ptr0[x0]; 37748736 == 18 * 2097152.
#     The value is widened to float32 by the load and stored through a bf16
#     pointer, so no arithmetic is applied to it.
#   * Despite the long fused name, none of the mul/neg/cat/add/index_put steps
#     listed in the graph fragment appear in this body; presumably they are
#     applied to the copied buffer by a separate kernel -- TODO confirm.
# NOTE(review): the kernel text is a runtime string handed to the compiler;
# it is intentionally left byte-identical.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_102 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_102', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_102', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_102(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (37748736 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/kp/ckp5xrso4nw3idrkzdjjpsehybu6baxn6oyhpxb7ai26wflddt7h.py
# Topologically Sorted Source Nodes: [setitem_37, view_56, value_states_37], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_37 => index_put_37, select_331, select_332, view_256
#   value_states_37 => permute_221
#   view_56 => view_254
# Graph fragment:
#   %buf500 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf500]
#   %select_scatter_default_71 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_71]
#   %select_int_36 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_71, 0, 0), kwargs = {})
#   %select_scatter_default_72 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_36, %index_put_36, 0, 18), kwargs = {})
#   %select_scatter_default_73 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_71, %select_scatter_default_72, 0, 0), kwargs = {})
#   %select_331 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_73, 0, 1), kwargs = {})
#   %select_332 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_331, 0, 18), kwargs = {})
#   %view_254 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_128, [1, 1, 2, 128]), kwargs = {})
#   %permute_221 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_254, [0, 2, 1, 3]), kwargs = {})
#   %view_256 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_221, [2, 1, 128]), kwargs = {})
#   %index_put_37 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_332, [None, None, %arg1_1], %view_256), kwargs = {})
#   return %index_put_37
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused_index_put_select_transpose_view_103' (device_str='cuda').
#
# What the kernel body does, read directly from its statements:
#   * 1-D pointwise pass; xnumel is overwritten with the constant 2097152 and
#     every load/store passes None as its mask (xmask is computed but unused).
#   * tmp2 compares the constants 1 and 0, so it is always False;
#     tmp4 compares 18 with itself, so it is always True.
#   * Hence tmp7 == tmp5 and tmp9 == tmp8: the value stored to out_ptr0[x0] is
#     always in_ptr1[96468992 + x0]. The loads of in_ptr0[x0] (tmp5) and
#     in_ptr1[37748736 + x0] (tmp6) never reach the output.
#   * 37748736 == 18 * 2097152 and 96468992 == 58720256 + 18 * 2097152.
#     Presumably these address block 18 of halves 0 and 1 of the
#     [2, 28, ...] buffer described in the graph fragment -- TODO confirm.
# NOTE(review): the kernel text is a runtime string handed to the compiler;
# it is intentionally left byte-identical.
triton_poi_fused_index_put_select_transpose_view_103 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_103', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_103', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_103(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (37748736 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (96468992 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 18, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/pq/cpq5t6nymskw45ddhngm27qidxaqeg7tceqydpj7aygdae7y4znh.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf503 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf503]
#   %buf500 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf500]
#   %select_scatter_default_71 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_71]
#   %select_int_36 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_71, 0, 0), kwargs = {})
#   %select_scatter_default_72 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_36, %index_put_36, 0, 18), kwargs = {})
#   %select_scatter_default_73 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_71, %select_scatter_default_72, 0, 0), kwargs = {})
#   %select_int_37 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_73, 0, 1), kwargs = {})
#   %select_scatter_default_74 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_37, %index_put_37, 0, 18), kwargs = {})
#   %select_scatter_default_75 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_73, %select_scatter_default_74, 0, 1), kwargs = {})
#   return %select_scatter_default_75
# Pointwise Triton kernel (1-D grid) registered through async_compile.triton.
# The flat output index (xnumel = 117440512 = 2 * 58720256) is decomposed as
#   x2 = index // 58720256         (outer slice, 0 or 1)
#   x1 = (index // 2097152) % 28   (position along the 28-long axis)
#   x0 = index % 2097152           (offset inside one 2097152-element block)
#   x3 = index % 58720256          (offset inside one outer slice)
# and the stored value is selected per element:
#   x2 == 1: in_ptr0[x0] when x1 == 18, otherwise in_ptr2[58720256 + x3]
#   x2 == 0: in_ptr1[x0] when x1 == 18, otherwise in_ptr2[x3]
# tmp8 compares the constants 1 and 0, so it is always False and tmp13 always
# equals tmp12. The fallback load tmp16 (in_ptr2[index]) is only selected when
# x2 is neither 0 nor 1, which cannot happen for indices below xnumel.
# Values are loaded from bf16 pointers, widened to float32, and stored through
# a bf16 pointer. All loads and the store are unmasked; xmask is computed but
# never used.
# NOTE(review): presumably in_ptr0/in_ptr1/in_ptr2 are buf503/buf500/
# select_scatter_default_71, in the order the placeholders are listed in the
# graph fragment above -- confirm against the call site.
triton_poi_fused_104 = async_compile.triton('triton_poi_fused_104', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_104', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_104(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 18, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4d/c4d7lgvsgl2ddhfty32wv5meioitwyzjgq6d75l3f2iytoeaggua.py
# Topologically Sorted Source Nodes: [attn_output_72], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_72 => clone_36, convert_element_type_535, expand_108, mul_205, permute_222, select_336, select_337, unsqueeze_36, view_257
# Graph fragment:
#   %select_scatter_default_75 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_75]
#   %select_336 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_75, 0, 0), kwargs = {})
#   %select_337 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_336, 0, 18), kwargs = {})
#   %convert_element_type_535 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_337, torch.float32), kwargs = {})
#   %unsqueeze_36 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_535, 2), kwargs = {})
#   %expand_108 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_36, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_36 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_108,), kwargs = {memory_format: torch.contiguous_format})
#   %view_257 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_36, [1, 16, 8192, 128]), kwargs = {})
#   %permute_222 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_257, [0, 1, 3, 2]), kwargs = {})
#   %mul_205 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_222, 0.29730177875068026), kwargs = {})
#   return %expand_111
# Pointwise Triton kernel on a 2-D grid (y from program_id(1), x from
# program_id(0)) covering ynumel = 2048 by xnumel = 8192 elements.
# With y0 = y % 128 and y1 = y // 128, each element is read from
#   in_ptr0[37748736 + y0 + 128*x + 1048576*(y1 // 8)]
# where 37748736 == 18 * 2097152, consistent with the select(..., 18) in the
# graph fragment above. Because only y1 // 8 enters the address, every group
# of 8 consecutive y1 values re-reads the same 1048576-element source block.
# The value is converted to float32, multiplied by 0.29730177875068026 and
# written to out_ptr0[x + 8192*y]: x is the unit-stride direction of the
# output, while y0 is the unit-stride direction of the input.
# tmp1 repeats the float32 conversion already applied at the load.
# The load and the store are unmasked; ymask/xmask are computed but unused.
# NOTE(review): the scale is numerically 128 ** -0.25; presumably derived from
# the trailing dimension of 128 -- confirm against the model code.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_105 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_105', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_105', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_105(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (37748736 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/al/calgpuv2k6a3vwk2avncg6lfgoe7igv5iazrecux55fxppvk6n7x.py
# Topologically Sorted Source Nodes: [setitem_37, attn_output_72], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_72 => clone_37, convert_element_type_536, expand_109, unsqueeze_37
#   setitem_37 => select_334, select_335
# Graph fragment:
#   %select_scatter_default_75 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_75]
#   %select_334 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_75, 0, 1), kwargs = {})
#   %select_335 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_334, 0, 18), kwargs = {})
#   %convert_element_type_536 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_335, torch.float32), kwargs = {})
#   %unsqueeze_37 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_536, 2), kwargs = {})
#   %expand_109 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_37, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_37 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_109,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_37
# Pointwise Triton kernel (1-D grid) that writes xnumel = 16777216 float32
# values read from a bf16 source.
# With x0 = index % 1048576 and x2 = index // 8388608, output element `index`
# is in_ptr0[96468992 + x0 + 1048576*x2], where
# 96468992 == 58720256 + 18 * 2097152 (consistent with select index 1 followed
# by select index 18 in the graph fragment above).
# The part of the index between x0 and x2, (index // 1048576) % 8, does not
# enter the load address, so each 1048576-element source block is written
# 8 times in a row.
# tmp1 repeats the float32 conversion already applied at the load.
# The load and the store are unmasked; xmask is computed but unused.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_106 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_106', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_106', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_106(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (96468992 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/fn/cfn7npozv2q3pr3nbatj4bj2lqywg5z3ggwnnfyql2v5xzkeh7dq.py
# Topologically Sorted Source Nodes: [cos, sin, view_58, key_states_58, k_19, chunk_39, setitem_38, mul_175, neg_39, cat_39, mul_176, k_embed_19, key_states_59], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_39 => cat_39
#   chunk_39 => split_39
#   cos => index
#   k_19 => convert_element_type_560
#   k_embed_19 => add_154
#   key_states_58 => permute_232
#   key_states_59 => convert_element_type_562
#   mul_175 => mul_213
#   mul_176 => mul_214
#   neg_39 => neg_58
#   setitem_38 => index_put_38, select_344, select_345, view_269
#   sin => index_1
#   view_58 => view_267
# Graph fragment:
#   %select_scatter_default_75 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_75]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_267 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_134, [1, 1, 2, 128]), kwargs = {})
#   %permute_232 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_267, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_560 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_232, torch.float32), kwargs = {})
#   %split_39 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_560, 64, -1), kwargs = {})
#   %select_344 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_75, 0, 0), kwargs = {})
#   %select_345 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_344, 0, 19), kwargs = {})
#   %mul_213 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_560, %index), kwargs = {})
#   %neg_58 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_79,), kwargs = {})
#   %cat_39 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_58, %getitem_78], -1), kwargs = {})
#   %mul_214 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_39, %index_1), kwargs = {})
#   %add_154 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_213, %mul_214), kwargs = {})
#   %convert_element_type_562 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_154, torch.bfloat16), kwargs = {})
#   %view_269 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_562, [2, 1, 128]), kwargs = {})
#   %index_put_38 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_345, [None, None, %arg1_1], %view_269), kwargs = {})
#   return %index_put_38
# Pointwise Triton kernel (1-D grid) that copies xnumel = 2097152 elements:
#   out_ptr0[i] = in_ptr0[39845888 + i],   39845888 == 19 * 2097152
# (consistent with select index 0 followed by select index 19 in the graph
# fragment above). Each value is loaded from a bf16 pointer, widened to
# float32 and stored through a bf16 pointer.
# Although the kernel name and the source-node header list index, cat, neg,
# mul, add and index_put, the body below contains only this load and store.
# NOTE(review): presumably the scatter of the new values into the copied block
# is performed by a separate call -- confirm against the call site.
# The load and the store are unmasked; xmask is computed but unused.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_107 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_107', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_107', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_107(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (39845888 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ne/cnefesbfv52cgqj6ly43dngrcwpko7bqjqyee4bp33vxodhmvxbm.py
# Topologically Sorted Source Nodes: [setitem_39, view_59, value_states_39], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_39 => index_put_39, select_349, select_350, view_270
#   value_states_39 => permute_233
#   view_59 => view_268
# Graph fragment:
#   %buf527 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf527]
#   %select_scatter_default_75 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_75]
#   %select_int_38 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_75, 0, 0), kwargs = {})
#   %select_scatter_default_76 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_38, %index_put_38, 0, 19), kwargs = {})
#   %select_scatter_default_77 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_75, %select_scatter_default_76, 0, 0), kwargs = {})
#   %select_349 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_77, 0, 1), kwargs = {})
#   %select_350 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_349, 0, 19), kwargs = {})
#   %view_268 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_135, [1, 1, 2, 128]), kwargs = {})
#   %permute_233 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_268, [0, 2, 1, 3]), kwargs = {})
#   %view_270 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_233, [2, 1, 128]), kwargs = {})
#   %index_put_39 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_350, [None, None, %arg1_1], %view_270), kwargs = {})
#   return %index_put_39
# Pointwise Triton kernel (1-D grid) over xnumel = 2097152 elements.
# Three values are loaded per element i:
#   tmp5 = in_ptr0[i]
#   tmp6 = in_ptr1[39845888 + i]    (39845888 == 19 * 2097152)
#   tmp8 = in_ptr1[98566144 + i]    (98566144 == 58720256 + 19 * 2097152)
# Both predicates compare constants: tmp2 is (1 == 0), always False, and tmp4
# is (19 == 19), always True. Hence tmp7 == tmp5 and the stored value tmp9 is
# always tmp8, i.e. out_ptr0[i] = in_ptr1[98566144 + i]; the other two loads do
# not influence the result.
# Values are loaded from bf16 pointers, widened to float32 and stored through
# a bf16 pointer. Loads and the store are unmasked; xmask is computed but
# unused.
# NOTE(review): presumably in_ptr0/in_ptr1 are buf527/select_scatter_default_75
# in the order listed in the graph fragment above -- confirm against the call
# site.
triton_poi_fused_index_put_select_transpose_view_108 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_108', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_108', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_108(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (39845888 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (98566144 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 19, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/q2/cq2agq7rudkgc6hta6hftwxysmpfjtoebo2qmbuyoncf2hxp5j6o.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf530 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf530]
#   %buf527 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf527]
#   %select_scatter_default_75 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_75]
#   %select_int_38 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_75, 0, 0), kwargs = {})
#   %select_scatter_default_76 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_38, %index_put_38, 0, 19), kwargs = {})
#   %select_scatter_default_77 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_75, %select_scatter_default_76, 0, 0), kwargs = {})
#   %select_int_39 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_77, 0, 1), kwargs = {})
#   %select_scatter_default_78 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_39, %index_put_39, 0, 19), kwargs = {})
#   %select_scatter_default_79 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_77, %select_scatter_default_78, 0, 1), kwargs = {})
#   return %select_scatter_default_79
# Pointwise Triton kernel (1-D grid) registered through async_compile.triton.
# The flat output index (xnumel = 117440512 = 2 * 58720256) is decomposed as
#   x2 = index // 58720256         (outer slice, 0 or 1)
#   x1 = (index // 2097152) % 28   (position along the 28-long axis)
#   x0 = index % 2097152           (offset inside one 2097152-element block)
#   x3 = index % 58720256          (offset inside one outer slice)
# and the stored value is selected per element:
#   x2 == 1: in_ptr0[x0] when x1 == 19, otherwise in_ptr2[58720256 + x3]
#   x2 == 0: in_ptr1[x0] when x1 == 19, otherwise in_ptr2[x3]
# tmp8 compares the constants 1 and 0, so it is always False and tmp13 always
# equals tmp12. The fallback load tmp16 (in_ptr2[index]) is only selected when
# x2 is neither 0 nor 1, which cannot happen for indices below xnumel.
# Values are loaded from bf16 pointers, widened to float32, and stored through
# a bf16 pointer. All loads and the store are unmasked; xmask is computed but
# never used.
# NOTE(review): presumably in_ptr0/in_ptr1/in_ptr2 are buf530/buf527/
# select_scatter_default_75, in the order the placeholders are listed in the
# graph fragment above -- confirm against the call site.
triton_poi_fused_109 = async_compile.triton('triton_poi_fused_109', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_109', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_109(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 19, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/56/c56fpore3r4bxmucnkz2gfjbqlj5hqr2p4sslx2fuvrsidsldo4w.py
# Topologically Sorted Source Nodes: [attn_output_76], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_76 => clone_38, convert_element_type_564, expand_114, mul_216, permute_234, select_354, select_355, unsqueeze_38, view_271
# Graph fragment:
#   %select_scatter_default_79 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_79]
#   %select_354 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_79, 0, 0), kwargs = {})
#   %select_355 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_354, 0, 19), kwargs = {})
#   %convert_element_type_564 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_355, torch.float32), kwargs = {})
#   %unsqueeze_38 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_564, 2), kwargs = {})
#   %expand_114 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_38, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_38 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_114,), kwargs = {memory_format: torch.contiguous_format})
#   %view_271 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_38, [1, 16, 8192, 128]), kwargs = {})
#   %permute_234 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_271, [0, 1, 3, 2]), kwargs = {})
#   %mul_216 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_234, 0.29730177875068026), kwargs = {})
#   return %expand_117
# Kernel summary (derived from the Triton source in the string below):
#   * Pointwise kernel over a 2-D index space, ynumel = 2048 by xnumel = 8192,
#     tiled as YBLOCK x XBLOCK (program_id(1) walks y, program_id(0) walks x).
#   * For every (y, x) it loads one bf16 element from
#         in_ptr0[39845888 + (y % 128) + 128*x + 1048576*((y // 128) // 8)]
#     converts it to fp32, multiplies by 0.29730177875068026 and stores the
#     result to out_ptr0[x + 8192*y].
#   * 39845888 == 19 * 2097152, which matches the select(..., 0, 19) on a
#     stride-2097152 dimension listed in the graph fragment above.
#   * y // 128 ranges over 0..15 and is integer-divided by 8, so eight
#     consecutive 128-row groups of y read the same 1048576-element source slab.
#   * Loads and stores pass None as the mask; ymask / xmask are constructed but
#     never used.  NOTE(review): this presumably relies on the launch grid
#     covering exactly 2048 x 8192 elements -- confirm against the caller.
#   * The second `.to(tl.float32)` (tmp1) is applied to a value that was already
#     converted to fp32 on the load line.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_110 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_110', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_110', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_110(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (39845888 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/t7/ct7nkyzqmjqvym4hrzkxtb2qx5z6lahtrglv4zbbsw3gexxbuq46.py
# Topologically Sorted Source Nodes: [setitem_39, attn_output_76], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_76 => clone_39, convert_element_type_565, expand_115, unsqueeze_39
#   setitem_39 => select_352, select_353
# Graph fragment:
#   %select_scatter_default_79 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_79]
#   %select_352 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_79, 0, 1), kwargs = {})
#   %select_353 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_352, 0, 19), kwargs = {})
#   %convert_element_type_565 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_353, torch.float32), kwargs = {})
#   %unsqueeze_39 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_565, 2), kwargs = {})
#   %expand_115 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_39, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_39 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_115,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_39
# Kernel summary (derived from the Triton source in the string below):
#   * 1-D pointwise kernel over xnumel = 16777216 output elements.
#   * For every flat index x it loads one bf16 element from
#         in_ptr0[98566144 + (x % 1048576) + 1048576*(x // 8388608)]
#     converts it to fp32 and stores it to out_ptr0[x].
#   * 98566144 == 58720256 + 19 * 2097152, which matches select(..., 0, 1)
#     followed by select(..., 0, 19) with the strides shown in the graph
#     fragment above.
#   * x // 8388608 is 0 or 1 within xnumel, and 8388608 / 1048576 == 8, so each
#     of the two 1048576-element source slabs is written out eight times in a
#     row.
#   * The load and store pass None as the mask; xmask is constructed but never
#     used.  NOTE(review): presumably the launch grid covers exactly xnumel
#     elements -- confirm against the caller.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_111 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_111', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_111', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_111(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (98566144 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ii/cii7jd65ya2wmkepo7p7lx5kctv3u3fg5xhjjfbhkrg5xf7c3hmf.py
# Topologically Sorted Source Nodes: [cos, sin, view_61, key_states_61, k_20, chunk_41, setitem_40, mul_184, neg_41, cat_41, mul_185, k_embed_20, key_states_62], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_41 => cat_41
#   chunk_41 => split_41
#   cos => index
#   k_20 => convert_element_type_589
#   k_embed_20 => add_162
#   key_states_61 => permute_244
#   key_states_62 => convert_element_type_591
#   mul_184 => mul_224
#   mul_185 => mul_225
#   neg_41 => neg_61
#   setitem_40 => index_put_40, select_362, select_363, view_283
#   sin => index_1
#   view_61 => view_281
# Graph fragment:
#   %select_scatter_default_79 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_79]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_281 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_141, [1, 1, 2, 128]), kwargs = {})
#   %permute_244 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_281, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_589 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_244, torch.float32), kwargs = {})
#   %split_41 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_589, 64, -1), kwargs = {})
#   %select_362 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_79, 0, 0), kwargs = {})
#   %select_363 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_362, 0, 20), kwargs = {})
#   %mul_224 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_589, %index), kwargs = {})
#   %neg_61 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_83,), kwargs = {})
#   %cat_41 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_61, %getitem_82], -1), kwargs = {})
#   %mul_225 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_41, %index_1), kwargs = {})
#   %add_162 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_224, %mul_225), kwargs = {})
#   %convert_element_type_591 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_162, torch.bfloat16), kwargs = {})
#   %view_283 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_591, [2, 1, 128]), kwargs = {})
#   %index_put_40 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_363, [None, None, %arg1_1], %view_283), kwargs = {})
#   return %index_put_40
# Kernel summary (derived from the Triton source in the string below):
#   * 1-D pointwise kernel over xnumel = 2097152 elements.
#   * Although the kernel name lists many fused ops, the body visible here is a
#     single load and a single store: out_ptr0[x] = in_ptr0[41943040 + x].
#   * 41943040 == 20 * 2097152, which matches the select(..., 0, 20) on a
#     stride-2097152 dimension listed in the graph fragment above.
#   * The loaded bf16 value is converted to fp32 on the load line and then
#     written through out_ptr0, which the signature declares as '*bf16'.
#   * The load and store pass None as the mask; xmask is constructed but never
#     used.  NOTE(review): presumably the launch grid covers exactly xnumel
#     elements -- confirm against the caller.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_112 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_112', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_112', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_112(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (41943040 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/hr/chrfmy2txrc2ko2l5cf6dpllwwx2czm2xzvuhkszlm7qcom2cxpj.py
# Topologically Sorted Source Nodes: [setitem_41, view_62, value_states_41], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_41 => index_put_41, select_367, select_368, view_284
#   value_states_41 => permute_245
#   view_62 => view_282
# Graph fragment:
#   %buf555 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf555]
#   %select_scatter_default_79 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_79]
#   %select_int_40 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_79, 0, 0), kwargs = {})
#   %select_scatter_default_80 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_40, %index_put_40, 0, 20), kwargs = {})
#   %select_scatter_default_81 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_79, %select_scatter_default_80, 0, 0), kwargs = {})
#   %select_367 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_81, 0, 1), kwargs = {})
#   %select_368 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_367, 0, 20), kwargs = {})
#   %view_282 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_142, [1, 1, 2, 128]), kwargs = {})
#   %permute_245 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_282, [0, 2, 1, 3]), kwargs = {})
#   %view_284 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_245, [2, 1, 128]), kwargs = {})
#   %index_put_41 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_368, [None, None, %arg1_1], %view_284), kwargs = {})
#   return %index_put_41
# Kernel summary (derived from the Triton source in the string below):
#   * 1-D pointwise kernel over xnumel = 2097152 elements.
#   * Three loads per element:
#         tmp5 = in_ptr0[x]
#         tmp6 = in_ptr1[41943040 + x]     (41943040  == 20 * 2097152)
#         tmp8 = in_ptr1[100663296 + x]    (100663296 == 58720256 + 20 * 2097152)
#   * The selection conditions are comparisons between constants:
#         tmp2 = (1 == 0)    -> always false
#         tmp4 = (20 == 20)  -> always true
#     so tmp7 = where(tmp4, tmp5, tmp6) evaluates to tmp5, and the stored value
#     tmp9 = where(tmp2, tmp7, tmp8) always evaluates to tmp8.  As written,
#     tmp5 and tmp6 therefore never influence the output: the kernel copies
#     in_ptr1[100663296 + x] to out_ptr0[x].
#   * The loads and store pass None as the mask; xmask is constructed but never
#     used.  NOTE(review): presumably the launch grid covers exactly xnumel
#     elements -- confirm against the caller.
triton_poi_fused_index_put_select_transpose_view_113 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_113', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_113', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_113(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (41943040 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (100663296 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 20, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/oc/cocpvqmrcmkrr34asj2b3bzylmluz6huacghlbae7whpszfesgd2.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf558 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf558]
#   %buf555 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf555]
#   %select_scatter_default_79 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_79]
#   %select_int_40 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_79, 0, 0), kwargs = {})
#   %select_scatter_default_80 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_40, %index_put_40, 0, 20), kwargs = {})
#   %select_scatter_default_81 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_79, %select_scatter_default_80, 0, 0), kwargs = {})
#   %select_int_41 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_81, 0, 1), kwargs = {})
#   %select_scatter_default_82 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_41, %index_put_41, 0, 20), kwargs = {})
#   %select_scatter_default_83 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_81, %select_scatter_default_82, 0, 1), kwargs = {})
#   return %select_scatter_default_83
# Kernel summary (derived from the Triton source in the string below):
#   * 1-D pointwise kernel over xnumel = 117440512 elements
#     (117440512 == 2 * 58720256 == 2 * 28 * 2097152).
#   * The flat index x4 is decomposed as
#         x2 = x4 // 58720256          (0 or 1 within xnumel)
#         x1 = (x4 // 2097152) % 28
#         x0 = x4 % 2097152
#         x3 = x4 % 58720256
#   * tmp8 compares the constants 1 and 0 and is therefore always false, so
#     tmp13 always evaluates to tmp12 = in_ptr2[58720256 + x3].
#   * Resulting value written to out_ptr0[x4]:
#         x2 == 1 and x1 == 20  ->  in_ptr0[x0]
#         x2 == 1 and x1 != 20  ->  in_ptr2[58720256 + x3]
#         x2 == 0 and x1 == 20  ->  in_ptr1[x0]
#         x2 == 0 and x1 != 20  ->  in_ptr2[x3]
#         otherwise             ->  in_ptr2[x4]
#     For x2 in {0, 1} the in_ptr2 addresses above equal x4, so the output is
#     in_ptr2 with the x1 == 20 slab of half 0 taken from in_ptr1 and the
#     x1 == 20 slab of half 1 taken from in_ptr0.  This corresponds to the two
#     nested select_scatter ops in the graph fragment above.
#   * All five loads are issued unconditionally for every element, and the
#     loads and store pass None as the mask; xmask is constructed but never
#     used.  NOTE(review): presumably the launch grid covers exactly xnumel
#     elements -- confirm against the caller.
triton_poi_fused_114 = async_compile.triton('triton_poi_fused_114', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_114', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_114(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 20, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/tu/ctuubtuv6k74b6ftceub37zxmcenbvosvruky47xfba7zehaprys.py
# Topologically Sorted Source Nodes: [attn_output_80], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_80 => clone_40, convert_element_type_593, expand_120, mul_227, permute_246, select_372, select_373, unsqueeze_40, view_285
# Graph fragment:
#   %select_scatter_default_83 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_83]
#   %select_372 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_83, 0, 0), kwargs = {})
#   %select_373 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_372, 0, 20), kwargs = {})
#   %convert_element_type_593 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_373, torch.float32), kwargs = {})
#   %unsqueeze_40 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_593, 2), kwargs = {})
#   %expand_120 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_40, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_40 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_120,), kwargs = {memory_format: torch.contiguous_format})
#   %view_285 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_40, [1, 16, 8192, 128]), kwargs = {})
#   %permute_246 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_285, [0, 1, 3, 2]), kwargs = {})
#   %mul_227 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_246, 0.29730177875068026), kwargs = {})
#   return %expand_123
# Kernel summary (derived from the Triton source in the string below):
#   * Pointwise kernel over a 2-D index space, ynumel = 2048 by xnumel = 8192,
#     tiled as YBLOCK x XBLOCK (program_id(1) walks y, program_id(0) walks x).
#   * For every (y, x) it loads one bf16 element from
#         in_ptr0[41943040 + (y % 128) + 128*x + 1048576*((y // 128) // 8)]
#     converts it to fp32, multiplies by 0.29730177875068026 and stores the
#     result to out_ptr0[x + 8192*y].
#   * 41943040 == 20 * 2097152, which matches the select(..., 0, 20) on a
#     stride-2097152 dimension listed in the graph fragment above.
#   * y // 128 ranges over 0..15 and is integer-divided by 8, so eight
#     consecutive 128-row groups of y read the same 1048576-element source slab.
#   * Loads and stores pass None as the mask; ymask / xmask are constructed but
#     never used.  NOTE(review): this presumably relies on the launch grid
#     covering exactly 2048 x 8192 elements -- confirm against the caller.
#   * The second `.to(tl.float32)` (tmp1) is applied to a value that was already
#     converted to fp32 on the load line.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_115 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_115', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_115', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_115(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (41943040 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/3p/c3pgefezazzifpsolsgtldwfsltnlxq5iee6kknjp3g26bgzgmav.py
# Topologically Sorted Source Nodes: [setitem_41, attn_output_80], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_80 => clone_41, convert_element_type_594, expand_121, unsqueeze_41
#   setitem_41 => select_370, select_371
# Graph fragment:
#   %select_scatter_default_83 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_83]
#   %select_370 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_83, 0, 1), kwargs = {})
#   %select_371 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_370, 0, 20), kwargs = {})
#   %convert_element_type_594 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_371, torch.float32), kwargs = {})
#   %unsqueeze_41 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_594, 2), kwargs = {})
#   %expand_121 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_41, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_41 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_121,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_41
# Reviewer summary of the kernel source below. The kernel text is a string literal
# handed to async_compile.triton, so it is intentionally left untouched.
#   * Pointwise 1-D kernel over 16777216 output elements. `xnumel` is overwritten
#     with that constant inside the body, and `xmask` is built but never used
#     (both the load and the store pass None as the mask).
#   * Reads bf16 from in_ptr0 at element offset 100663296 + x0 + 1048576*x2, with
#     x0 = xindex % 1048576 and x2 = xindex // 8388608. 100663296 equals
#     1*58720256 + 20*2097152, i.e. the select(0, 1) / select(0, 20) slice listed
#     in the graph fragment above.
#   * x2 only advances every 8388608 (= 8 * 1048576) outputs, so each
#     1048576-element source slab is written 8 times in a row (the expand to 8).
#   * The load result is already converted with .to(tl.float32); the second
#     conversion into tmp1 is fp32 -> fp32. The value is stored at out_ptr0 + xindex.
# NOTE(review): the unmasked accesses presumably rely on the launch grid covering
# exactly xnumel elements -- confirm against the pointwise heuristics.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_116 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_116', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_116', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_116(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (100663296 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/bz/cbzvfxzfxn3ov5ssixeceui3bxc6jz4lw535aufv3etxbsyuld7m.py
# Topologically Sorted Source Nodes: [cos, sin, view_64, key_states_64, k_21, chunk_43, setitem_42, mul_193, neg_43, cat_43, mul_194, k_embed_21, key_states_65], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_43 => cat_43
#   chunk_43 => split_43
#   cos => index
#   k_21 => convert_element_type_618
#   k_embed_21 => add_170
#   key_states_64 => permute_256
#   key_states_65 => convert_element_type_620
#   mul_193 => mul_235
#   mul_194 => mul_236
#   neg_43 => neg_64
#   setitem_42 => index_put_42, select_380, select_381, view_297
#   sin => index_1
#   view_64 => view_295
# Graph fragment:
#   %select_scatter_default_83 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_83]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_295 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_148, [1, 1, 2, 128]), kwargs = {})
#   %permute_256 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_295, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_618 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_256, torch.float32), kwargs = {})
#   %split_43 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_618, 64, -1), kwargs = {})
#   %select_380 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_83, 0, 0), kwargs = {})
#   %select_381 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_380, 0, 21), kwargs = {})
#   %mul_235 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_618, %index), kwargs = {})
#   %neg_64 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_87,), kwargs = {})
#   %cat_43 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_64, %getitem_86], -1), kwargs = {})
#   %mul_236 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_43, %index_1), kwargs = {})
#   %add_170 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_235, %mul_236), kwargs = {})
#   %convert_element_type_620 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_170, torch.bfloat16), kwargs = {})
#   %view_297 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_620, [2, 1, 128]), kwargs = {})
#   %index_put_42 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_381, [None, None, %arg1_1], %view_297), kwargs = {})
#   return %index_put_42
# Reviewer summary of the kernel source below. The kernel text is a string literal
# handed to async_compile.triton, so it is intentionally left untouched.
#   * Pointwise 1-D kernel over 2097152 elements. `xnumel` is overwritten with that
#     constant and `xmask` is built but unused (load/store pass None as the mask).
#   * The body is a plain copy: out_ptr0[x0] receives in_ptr0[44040192 + x0]. The
#     loaded value goes through .to(tl.float32) before the store; out_ptr0 is
#     declared '*bf16' in the signature.
#   * 44040192 equals 21*2097152, i.e. the select(0, 0) / select(0, 21) slice
#     listed in the graph fragment above.
# NOTE(review): the kernel name lists add/cat/mul/neg/index_put, but none of that
# arithmetic appears in this body; presumably the scatter of the new values is
# performed elsewhere -- confirm at the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_117 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_117', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_117', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_117(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (44040192 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zv/czvlwlvbuaxpaoegjqkxcx7rkl5bpzbmezvq3eqvmredooua2c3m.py
# Topologically Sorted Source Nodes: [setitem_43, view_65, value_states_43], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_43 => index_put_43, select_385, select_386, view_298
#   value_states_43 => permute_257
#   view_65 => view_296
# Graph fragment:
#   %buf582 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf582]
#   %select_scatter_default_83 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_83]
#   %select_int_42 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_83, 0, 0), kwargs = {})
#   %select_scatter_default_84 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_42, %index_put_42, 0, 21), kwargs = {})
#   %select_scatter_default_85 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_83, %select_scatter_default_84, 0, 0), kwargs = {})
#   %select_385 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_85, 0, 1), kwargs = {})
#   %select_386 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_385, 0, 21), kwargs = {})
#   %view_296 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_149, [1, 1, 2, 128]), kwargs = {})
#   %permute_257 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_296, [0, 2, 1, 3]), kwargs = {})
#   %view_298 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_257, [2, 1, 128]), kwargs = {})
#   %index_put_43 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_386, [None, None, %arg1_1], %view_298), kwargs = {})
#   return %index_put_43
# Reviewer summary of the kernel source below. The kernel text is a string literal
# handed to async_compile.triton, so it is intentionally left untouched.
#   * Pointwise 1-D kernel over 2097152 elements. `xnumel` is overwritten with that
#     constant and `xmask` is built but unused (loads/store pass None as the mask).
#   * Three loads are issued: in_ptr0[x0], in_ptr1[44040192 + x0] and
#     in_ptr1[102760448 + x0].
#   * tmp2 compares the constants 1 and 0, and tmp4 compares the constant 21 with
#     itself, so tmp2 is always False and tmp4 is always True. The final
#     tl.where(tmp2, tmp7, tmp8) therefore always yields tmp8: the stored value is
#     in_ptr1[102760448 + x0], and tmp5 / tmp6 never reach the output.
#   * 102760448 equals 1*58720256 + 21*2097152, i.e. the select(0, 1) /
#     select(0, 21) slice listed in the graph fragment above.
# NOTE(review): the constant-folded selects look like residue of the
# select_scatter chain in the graph fragment; whether the two unused loads are
# eliminated is up to the Triton compiler -- not verifiable from this file.
triton_poi_fused_index_put_select_transpose_view_118 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_118', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_118', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_118(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (44040192 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (102760448 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 21, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/mc/cmcmvlmbhbu74yqmfuary4rcysulk6jvcmhe747qdxvcmedtnsml.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf585 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf585]
#   %buf582 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf582]
#   %select_scatter_default_83 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_83]
#   %select_int_42 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_83, 0, 0), kwargs = {})
#   %select_scatter_default_84 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_42, %index_put_42, 0, 21), kwargs = {})
#   %select_scatter_default_85 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_83, %select_scatter_default_84, 0, 0), kwargs = {})
#   %select_int_43 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_85, 0, 1), kwargs = {})
#   %select_scatter_default_86 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_43, %index_put_43, 0, 21), kwargs = {})
#   %select_scatter_default_87 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_85, %select_scatter_default_86, 0, 1), kwargs = {})
#   return %select_scatter_default_87
# Reviewer summary of the kernel source below. The kernel text is a string literal
# handed to async_compile.triton, so it is intentionally left untouched.
#   * Pointwise 1-D kernel over 117440512 (= 2 * 58720256) elements. `xnumel` is
#     overwritten with that constant and `xmask` is built but unused (loads/store
#     pass None as the mask).
#   * Index decomposition: x2 = xindex // 58720256, x1 = (xindex // 2097152) % 28,
#     x0 = xindex % 2097152, x3 = xindex % 58720256, x4 = xindex.
#   * tmp8 compares the constants 1 and 0 and is always False, so tmp13 is always
#     tmp12 (in_ptr2[58720256 + x3]).
#   * Effective result stored at out_ptr0[x4]:
#       x2 == 1 and x1 == 21 -> in_ptr0[x0]
#       x2 == 1 otherwise    -> in_ptr2[58720256 + x3]
#       x2 == 0 and x1 == 21 -> in_ptr1[x0]
#       x2 == 0 otherwise    -> in_ptr2[x3]
#       any other x2         -> in_ptr2[x4]
#     For indices below xnumel, x2 is 0 or 1 and both in_ptr2 addresses above equal
#     x4, so the output is in_ptr2 with the x1 == 21 slab of each x2 half replaced
#     by in_ptr1 (x2 == 0) or in_ptr0 (x2 == 1).
# NOTE(review): in_ptr0 / in_ptr1 / in_ptr2 presumably correspond to buf585 /
# buf582 / select_scatter_default_83 in the order of the placeholders listed in the
# graph fragment above -- confirm at the call site.
triton_poi_fused_119 = async_compile.triton('triton_poi_fused_119', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_119', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_119(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 21, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/gf/cgfp7lizhwmuhdsbxzwdhubenhiwucysmjntbgpqucwhwiu22s7j.py
# Topologically Sorted Source Nodes: [attn_output_84], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_84 => clone_42, convert_element_type_622, expand_126, mul_238, permute_258, select_390, select_391, unsqueeze_42, view_299
# Graph fragment:
#   %select_scatter_default_87 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_87]
#   %select_390 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_87, 0, 0), kwargs = {})
#   %select_391 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_390, 0, 21), kwargs = {})
#   %convert_element_type_622 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_391, torch.float32), kwargs = {})
#   %unsqueeze_42 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_622, 2), kwargs = {})
#   %expand_126 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_42, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_42 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_126,), kwargs = {memory_format: torch.contiguous_format})
#   %view_299 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_42, [1, 16, 8192, 128]), kwargs = {})
#   %permute_258 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_299, [0, 1, 3, 2]), kwargs = {})
#   %mul_238 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_258, 0.29730177875068026), kwargs = {})
#   return %expand_129
# Reviewer summary of the kernel source below. The kernel text is a string literal
# handed to async_compile.triton, so it is intentionally left untouched.
#   * 2-D pointwise kernel ('Grid2D'): program_id(1) tiles y (ynumel = 2048) and
#     program_id(0) tiles x (xnumel = 8192). Both sizes are overwritten with
#     constants in the body; `ymask` / `xmask` are built but unused (load/store pass
#     None as the mask).
#   * y is split as y0 = yindex % 128 and y1 = yindex // 128. The load address is
#     44040192 + y0 + 128*x2 + 1048576*(y1 // 8): y0 walks unit-stride source
#     elements, x2 steps by 128, and 8 consecutive y1 values share the same
#     1048576-element source slab (y1 // 8).
#   * 44040192 equals 21*2097152, i.e. the select(0, 0) / select(0, 21) slice
#     listed in the graph fragment above.
#   * The loaded value is converted to fp32 (the second conversion into tmp1 is
#     fp32 -> fp32), multiplied by 0.29730177875068026 and stored at
#     out_ptr0[x2 + 8192*y3], so x is the unit-stride axis of the output while it
#     is the stride-128 axis of the input.
# NOTE(review): 0.29730177875068026 is numerically 128 ** -0.25; presumably a
# scaling factor tied to the 128-wide last dimension -- confirm against the model
# code.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_120 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_120', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_120', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_120(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (44040192 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qh/cqhpah6yvnhgj7h52nnxbp7miapktmyg3q6wymwpzxmyk63qen3k.py
# Topologically Sorted Source Nodes: [setitem_43, attn_output_84], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_84 => clone_43, convert_element_type_623, expand_127, unsqueeze_43
#   setitem_43 => select_388, select_389
# Graph fragment:
#   %select_scatter_default_87 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_87]
#   %select_388 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_87, 0, 1), kwargs = {})
#   %select_389 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_388, 0, 21), kwargs = {})
#   %convert_element_type_623 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_389, torch.float32), kwargs = {})
#   %unsqueeze_43 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_623, 2), kwargs = {})
#   %expand_127 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_43, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_43 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_127,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_43
# Reviewer summary of the kernel source below. The kernel text is a string literal
# handed to async_compile.triton, so it is intentionally left untouched.
#   * Pointwise 1-D kernel over 16777216 output elements. `xnumel` is overwritten
#     with that constant inside the body, and `xmask` is built but never used
#     (both the load and the store pass None as the mask).
#   * Reads bf16 from in_ptr0 at element offset 102760448 + x0 + 1048576*x2, with
#     x0 = xindex % 1048576 and x2 = xindex // 8388608. 102760448 equals
#     1*58720256 + 21*2097152, i.e. the select(0, 1) / select(0, 21) slice listed
#     in the graph fragment above.
#   * x2 only advances every 8388608 (= 8 * 1048576) outputs, so each
#     1048576-element source slab is written 8 times in a row (the expand to 8).
#   * The load result is already converted with .to(tl.float32); the second
#     conversion into tmp1 is fp32 -> fp32. The value is stored at out_ptr0 + xindex.
# NOTE(review): the unmasked accesses presumably rely on the launch grid covering
# exactly xnumel elements -- confirm against the pointwise heuristics.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_121 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_121', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_121', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_121(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (102760448 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zq/czqjr3fniumdgivpn62iiusjfii5tzfzoo43eknpz7swivssgnqy.py
# Topologically Sorted Source Nodes: [cos, sin, view_67, key_states_67, k_22, chunk_45, setitem_44, mul_202, neg_45, cat_45, mul_203, k_embed_22, key_states_68], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_45 => cat_45
#   chunk_45 => split_45
#   cos => index
#   k_22 => convert_element_type_647
#   k_embed_22 => add_178
#   key_states_67 => permute_268
#   key_states_68 => convert_element_type_649
#   mul_202 => mul_246
#   mul_203 => mul_247
#   neg_45 => neg_67
#   setitem_44 => index_put_44, select_398, select_399, view_311
#   sin => index_1
#   view_67 => view_309
# Graph fragment:
#   %select_scatter_default_87 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_87]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_309 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_155, [1, 1, 2, 128]), kwargs = {})
#   %permute_268 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_309, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_647 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_268, torch.float32), kwargs = {})
#   %split_45 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_647, 64, -1), kwargs = {})
#   %select_398 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_87, 0, 0), kwargs = {})
#   %select_399 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_398, 0, 22), kwargs = {})
#   %mul_246 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_647, %index), kwargs = {})
#   %neg_67 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_91,), kwargs = {})
#   %cat_45 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_67, %getitem_90], -1), kwargs = {})
#   %mul_247 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_45, %index_1), kwargs = {})
#   %add_178 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_246, %mul_247), kwargs = {})
#   %convert_element_type_649 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_178, torch.bfloat16), kwargs = {})
#   %view_311 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_649, [2, 1, 128]), kwargs = {})
#   %index_put_44 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_399, [None, None, %arg1_1], %view_311), kwargs = {})
#   return %index_put_44
# Registers a pointwise Triton kernel: the kernel source is passed as a string to
# async_compile.triton for device 'cuda'.
# Kernel behaviour: out_ptr0[x0] = in_ptr0[46137344 + x0] for the XBLOCK-wide tile
# owned by each program. xnumel is hard-coded to 2097152 and the load/store are
# unmasked (mask argument is None; xmask is built but never used).
# 46137344 == 22 * 2097152, which matches the select(..., 22) over a
# stride-2097152 dimension shown in the graph fragment above.
# The loaded value is widened to float32 and stored through a '*bf16' pointer.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_122 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_122', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_122', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_122(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (46137344 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/nq/cnqoa6hdovspi5q5vaqyvh3n2ayfatzj7o7r5letshfpk5bc6t54.py
# Topologically Sorted Source Nodes: [setitem_45, view_68, value_states_45], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_45 => index_put_45, select_403, select_404, view_312
#   value_states_45 => permute_269
#   view_68 => view_310
# Graph fragment:
#   %buf610 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf610]
#   %select_scatter_default_87 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_87]
#   %select_int_44 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_87, 0, 0), kwargs = {})
#   %select_scatter_default_88 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_44, %index_put_44, 0, 22), kwargs = {})
#   %select_scatter_default_89 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_87, %select_scatter_default_88, 0, 0), kwargs = {})
#   %select_403 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_89, 0, 1), kwargs = {})
#   %select_404 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_403, 0, 22), kwargs = {})
#   %view_310 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_156, [1, 1, 2, 128]), kwargs = {})
#   %permute_269 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_310, [0, 2, 1, 3]), kwargs = {})
#   %view_312 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_269, [2, 1, 128]), kwargs = {})
#   %index_put_45 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_404, [None, None, %arg1_1], %view_312), kwargs = {})
#   return %index_put_45
# Registers a pointwise Triton kernel (source string compiled via
# async_compile.triton for device 'cuda') over xnumel = 2097152 elements.
# Kernel behaviour: three unmasked loads are issued (in_ptr0[x0],
# in_ptr1[46137344 + x0], in_ptr1[104857600 + x0]) and combined with two tl.where
# selects. Both predicates compare constants: tmp4 is (22 == 22) and tmp2 is
# (1 == 0). Because tmp2 is always false, the value stored to out_ptr0[x0] is
# always tmp8, i.e. in_ptr1[104857600 + x0]; the other two loads do not affect
# the result.
# 104857600 == 58720256 + 22 * 2097152, consistent with select(..., 1) followed by
# select(..., 22) on the strides listed in the graph fragment above.
# NOTE(review): the constant predicates look like residue of the
# select_scatter/select chain being resolved at codegen time - confirm before
# relying on that interpretation.
triton_poi_fused_index_put_select_transpose_view_123 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_123', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_123', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_123(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (46137344 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (104857600 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 22, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/hj/chj6yausfgtlrhtsegk5wcuvaudulnblqqibwiqbsnmdr452xt6o.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf613 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf613]
#   %buf610 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf610]
#   %select_scatter_default_87 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_87]
#   %select_int_44 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_87, 0, 0), kwargs = {})
#   %select_scatter_default_88 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_44, %index_put_44, 0, 22), kwargs = {})
#   %select_scatter_default_89 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_87, %select_scatter_default_88, 0, 0), kwargs = {})
#   %select_int_45 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_89, 0, 1), kwargs = {})
#   %select_scatter_default_90 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_45, %index_put_45, 0, 22), kwargs = {})
#   %select_scatter_default_91 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_89, %select_scatter_default_90, 0, 1), kwargs = {})
#   return %select_scatter_default_91
# Registers a pointwise Triton kernel (source string compiled via
# async_compile.triton for device 'cuda') over xnumel = 117440512 elements
# (== 2 * 28 * 2097152).
# Index decomposition of the flat index:
#   x2 = xindex // 58720256          (outer slot)
#   x1 = (xindex // 2097152) % 28    (slot within the outer slice)
#   x0 = xindex % 2097152            (offset inside one 2097152-element slice)
#   x3 = xindex % 58720256           (offset inside one outer slice)
# Value written to out_ptr0[xindex], after folding the constant predicate
# tmp8 = (1 == 0), which is always false:
#   x2 == 1 and x1 == 22 -> in_ptr0[x0]
#   x2 == 1 and x1 != 22 -> in_ptr2[58720256 + x3]
#   x2 == 0 and x1 == 22 -> in_ptr1[x0]
#   x2 == 0 and x1 != 22 -> in_ptr2[x3]
#   any other x2         -> in_ptr2[xindex]
# All five loads are issued unconditionally and unmasked; the tl.where chain
# picks the result. Loads whose address repeats across x1/x2 carry
# eviction_policy='evict_last'.
triton_poi_fused_124 = async_compile.triton('triton_poi_fused_124', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_124', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_124(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 22, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/vf/cvfbgrrjpjuksfvti64z3xgbanz24rceoz2nprhz6oidjxpcz2i2.py
# Topologically Sorted Source Nodes: [attn_output_88], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_88 => clone_44, convert_element_type_651, expand_132, mul_249, permute_270, select_408, select_409, unsqueeze_44, view_313
# Graph fragment:
#   %select_scatter_default_91 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_91]
#   %select_408 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_91, 0, 0), kwargs = {})
#   %select_409 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_408, 0, 22), kwargs = {})
#   %convert_element_type_651 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_409, torch.float32), kwargs = {})
#   %unsqueeze_44 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_651, 2), kwargs = {})
#   %expand_132 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_44, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_44 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_132,), kwargs = {memory_format: torch.contiguous_format})
#   %view_313 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_44, [1, 16, 8192, 128]), kwargs = {})
#   %permute_270 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_313, [0, 1, 3, 2]), kwargs = {})
#   %mul_249 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_270, 0.29730177875068026), kwargs = {})
#   return %expand_135
# Registers a 2D-tiled pointwise Triton kernel (source string compiled via
# async_compile.triton for device 'cuda'; Grid2D, TileHint.SQUARE) with
# ynumel = 2048 and xnumel = 8192.
# Index decomposition: y0 = yindex % 128, y1 = yindex // 128, x2 = xindex.
# Kernel behaviour:
#   out_ptr0[x2 + 8192 * yindex] =
#       float32(in_ptr0[46137344 + y0 + 128 * x2 + 1048576 * (y1 // 8)]) * 0.29730177875068026
# The y1 // 8 term makes eight consecutive y1 values read the same
# 1048576-element source block, and the x/y roles are swapped between the load
# (x strided by 128) and the store (x contiguous), i.e. the output is written
# transposed relative to the source layout.
# 46137344 == 22 * 2097152, matching select(..., 0) then select(..., 22) in the
# graph fragment above. The output pointer is '*fp32', so no narrowing happens
# on store. tmp1 re-casts an already-float32 value and is a no-op.
# NOTE(review): 0.29730177875068026 is numerically 128 ** -0.25 - presumably an
# attention scale split across two operands; confirm against the model code.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_125 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_125', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_125', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_125(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (46137344 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/gl/cglsjdk7hg5fs62ks7v4cfypt6kyk6yn7ewsjjg76lbenwrnmvn3.py
# Topologically Sorted Source Nodes: [setitem_45, attn_output_88], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_88 => clone_45, convert_element_type_652, expand_133, unsqueeze_45
#   setitem_45 => select_406, select_407
# Graph fragment:
#   %select_scatter_default_91 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_91]
#   %select_406 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_91, 0, 1), kwargs = {})
#   %select_407 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_406, 0, 22), kwargs = {})
#   %convert_element_type_652 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_407, torch.float32), kwargs = {})
#   %unsqueeze_45 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_652, 2), kwargs = {})
#   %expand_133 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_45, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_45 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_133,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_45
# Registers a pointwise Triton kernel (source string compiled via
# async_compile.triton for device 'cuda') over xnumel = 16777216 elements.
# Index decomposition: x0 = xindex % 1048576, x2 = xindex // 8388608.
# Kernel behaviour:
#   out_ptr0[xindex] = float32(in_ptr0[104857600 + x0 + 1048576 * x2])
# Since 8388608 == 8 * 1048576, each 1048576-element source block selected by x2
# is written eight times in a row to the output, which materialises the
# expand-to-8 + contiguous clone described in the graph fragment above.
# 104857600 == 58720256 + 22 * 2097152, matching select(..., 1) then
# select(..., 22). The output pointer is '*fp32'; tmp1 re-casts an
# already-float32 value and is a no-op. Load and store are unmasked.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_126 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_126', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_126', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_126(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (104857600 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zp/czpmpot2rmcj5ztn3coqcmlkt6zjrcg65kew6fybzuytqdft4jdh.py
# Topologically Sorted Source Nodes: [cos, sin, view_70, key_states_70, k_23, chunk_47, setitem_46, mul_211, neg_47, cat_47, mul_212, k_embed_23, key_states_71], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_47 => cat_47
#   chunk_47 => split_47
#   cos => index
#   k_23 => convert_element_type_676
#   k_embed_23 => add_186
#   key_states_70 => permute_280
#   key_states_71 => convert_element_type_678
#   mul_211 => mul_257
#   mul_212 => mul_258
#   neg_47 => neg_70
#   setitem_46 => index_put_46, select_416, select_417, view_325
#   sin => index_1
#   view_70 => view_323
# Graph fragment:
#   %select_scatter_default_91 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_91]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_323 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_162, [1, 1, 2, 128]), kwargs = {})
#   %permute_280 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_323, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_676 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_280, torch.float32), kwargs = {})
#   %split_47 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_676, 64, -1), kwargs = {})
#   %select_416 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_91, 0, 0), kwargs = {})
#   %select_417 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_416, 0, 23), kwargs = {})
#   %mul_257 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %index), kwargs = {})
#   %neg_70 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_95,), kwargs = {})
#   %cat_47 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_70, %getitem_94], -1), kwargs = {})
#   %mul_258 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_47, %index_1), kwargs = {})
#   %add_186 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_257, %mul_258), kwargs = {})
#   %convert_element_type_678 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_186, torch.bfloat16), kwargs = {})
#   %view_325 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_678, [2, 1, 128]), kwargs = {})
#   %index_put_46 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_417, [None, None, %arg1_1], %view_325), kwargs = {})
#   return %index_put_46
# Registers a pointwise Triton kernel: the kernel source is passed as a string to
# async_compile.triton for device 'cuda'.
# Kernel behaviour: out_ptr0[x0] = in_ptr0[48234496 + x0] for the XBLOCK-wide tile
# owned by each program. xnumel is hard-coded to 2097152 and the load/store are
# unmasked (mask argument is None; xmask is built but never used).
# 48234496 == 23 * 2097152, which matches select(..., 0) followed by
# select(..., 23) over a stride-2097152 dimension in the graph fragment above.
# The loaded value is widened to float32 and stored through a '*bf16' pointer.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_127 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_127', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_127', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_127(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (48234496 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qr/cqryk5efo7r7zaex5trfj6dxcspjwxc3xq4dy2xwwk4y7t4drfnv.py
# Topologically Sorted Source Nodes: [setitem_47, view_71, value_states_47], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_47 => index_put_47, select_421, select_422, view_326
#   value_states_47 => permute_281
#   view_71 => view_324
# Graph fragment:
#   %buf637 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf637]
#   %select_scatter_default_91 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_91]
#   %select_int_46 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_91, 0, 0), kwargs = {})
#   %select_scatter_default_92 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_46, %index_put_46, 0, 23), kwargs = {})
#   %select_scatter_default_93 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_91, %select_scatter_default_92, 0, 0), kwargs = {})
#   %select_421 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_93, 0, 1), kwargs = {})
#   %select_422 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_421, 0, 23), kwargs = {})
#   %view_324 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_163, [1, 1, 2, 128]), kwargs = {})
#   %permute_281 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_324, [0, 2, 1, 3]), kwargs = {})
#   %view_326 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_281, [2, 1, 128]), kwargs = {})
#   %index_put_47 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_422, [None, None, %arg1_1], %view_326), kwargs = {})
#   return %index_put_47
# NOTE(review): Registers the Triton source below with async_compile under the
# name 'triton_poi_fused_index_put_select_transpose_view_128' for device 'cuda'.
#
# Kernel summary (1-D pointwise, xnumel hard-coded to 2097152 inside the body,
# overriding the runtime argument):
#   * Three loads are issued per element: in_ptr0[x0], in_ptr1[48234496 + x0]
#     and in_ptr1[106954752 + x0], each converted to float32.
#     48234496 == 23 * 2097152 and 106954752 == 58720256 + 23 * 2097152.
#   * Both tl.where predicates are compile-time constants: tmp2 is (1 == 0),
#     i.e. always False, and tmp4 is (23 == 23), i.e. always True.  Hence
#     tmp9 is always tmp8, so the value stored to out_ptr0[x0] is
#     in_ptr1[106954752 + x0]; tmp5/tmp6 never reach the output.
#   * xmask is built but unused; every load and the store pass None as mask.
#     Presumably the launch grid covers exactly xnumel elements -- TODO confirm
#     against the grid computed by the pointwise heuristic.
triton_poi_fused_index_put_select_transpose_view_128 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_128', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_128', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_128(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (48234496 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (106954752 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 23, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/53/c53iczimvdufyk6mb6nkuelotbqgohvw3eb5sdjteekmczx5avbl.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf640 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf640]
#   %buf637 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf637]
#   %select_scatter_default_91 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_91]
#   %select_int_46 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_91, 0, 0), kwargs = {})
#   %select_scatter_default_92 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_46, %index_put_46, 0, 23), kwargs = {})
#   %select_scatter_default_93 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_91, %select_scatter_default_92, 0, 0), kwargs = {})
#   %select_int_47 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_93, 0, 1), kwargs = {})
#   %select_scatter_default_94 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_47, %index_put_47, 0, 23), kwargs = {})
#   %select_scatter_default_95 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_93, %select_scatter_default_94, 0, 1), kwargs = {})
#   return %select_scatter_default_95
# NOTE(review): Registers the Triton source below with async_compile under the
# name 'triton_poi_fused_129' for device 'cuda'.
#
# Kernel summary (1-D pointwise, xnumel hard-coded to 117440512
# == 2 * 58720256 inside the body):
#   * Index decomposition of the flat index xindex (== x4):
#       x2 = xindex // 58720256          (outermost index; 0 or 1 for xindex < xnumel)
#       x1 = (xindex // 2097152) % 28    (index within the 28-long dimension)
#       x0 = xindex % 2097152            (offset inside one 2097152-element slab)
#       x3 = xindex % 58720256           (offset inside one x2 half)
#   * tmp8 is (1 == 0), a constant False, so tmp13 always equals tmp12
#     (in_ptr2[58720256 + x3]).
#   * Effective result stored to out_ptr0[x4]:
#       x2 == 1:  in_ptr0[x0] if x1 == 23 else in_ptr2[58720256 + x3]
#       x2 == 0:  in_ptr1[x0] if x1 == 23 else in_ptr2[x3]
#       otherwise in_ptr2[x4] (not reachable while xindex < xnumel).
#     In both reachable cases the in_ptr2 offset equals x4, i.e. the kernel
#     copies in_ptr2 and substitutes the x1 == 23 slab of each half with
#     in_ptr1 (x2 == 0) or in_ptr0 (x2 == 1).
#   * xmask is built but unused; all loads and the store pass None as mask.
triton_poi_fused_129 = async_compile.triton('triton_poi_fused_129', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_129', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_129(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 23, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/vh/cvhwectyh2dwwyjvm5btvc7pyr4acrojy4rvhn7dmxkb2vkdd7uo.py
# Topologically Sorted Source Nodes: [attn_output_92], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_92 => clone_46, convert_element_type_680, expand_138, mul_260, permute_282, select_426, select_427, unsqueeze_46, view_327
# Graph fragment:
#   %select_scatter_default_95 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_95]
#   %select_426 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_95, 0, 0), kwargs = {})
#   %select_427 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_426, 0, 23), kwargs = {})
#   %convert_element_type_680 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_427, torch.float32), kwargs = {})
#   %unsqueeze_46 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_680, 2), kwargs = {})
#   %expand_138 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_46, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_46 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_138,), kwargs = {memory_format: torch.contiguous_format})
#   %view_327 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_46, [1, 16, 8192, 128]), kwargs = {})
#   %permute_282 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_327, [0, 1, 3, 2]), kwargs = {})
#   %mul_260 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_282, 0.29730177875068026), kwargs = {})
#   return %expand_141
# NOTE(review): Registers the Triton source below with async_compile under the
# name 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_130'
# for device 'cuda'.
#
# Kernel summary (2-D pointwise tile, ynumel = 2048 and xnumel = 8192 hard-coded
# inside the body; y comes from program_id(1), x from program_id(0)):
#   * y0 = yindex % 128 and y1 = yindex // 128 (0..15 for yindex < 2048).
#   * Source element: in_ptr0[48234496 + y0 + 128*x2 + 1048576*(y1 // 8)],
#     where 48234496 == 23 * 2097152.  Because of (y1 // 8), eight consecutive
#     y1 values read the same 1048576-element source chunk.
#   * The bf16 value (per the signature) is converted to float32, multiplied by
#     0.29730177875068026 (numerically 128 ** -0.25; the meaning of that scale
#     is not established in this block), and written to
#     out_ptr0[x2 + 8192*y3], i.e. x is the contiguous output axis while it is
#     the stride-128 axis of the input.
#   * tmp1 = tmp0.to(tl.float32) re-casts a value that is already float32.
#   * ymask/xmask are built but unused; the load and store pass None as mask.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_130 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_130', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_130', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_130(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (48234496 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4u/c4up4rwqzxxrx7zxyhasnug7vhwvncdd3o5eumjqdkzn34vx253j.py
# Topologically Sorted Source Nodes: [setitem_47, attn_output_92], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_92 => clone_47, convert_element_type_681, expand_139, unsqueeze_47
#   setitem_47 => select_424, select_425
# Graph fragment:
#   %select_scatter_default_95 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_95]
#   %select_424 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_95, 0, 1), kwargs = {})
#   %select_425 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_424, 0, 23), kwargs = {})
#   %convert_element_type_681 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_425, torch.float32), kwargs = {})
#   %unsqueeze_47 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_681, 2), kwargs = {})
#   %expand_139 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_47, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_47 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_139,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_47
# NOTE(review): Registers the Triton source below with async_compile under the
# name 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_131' for
# device 'cuda'.
#
# Kernel summary (1-D pointwise, xnumel hard-coded to 16777216 inside the body):
#   * x0 = xindex % 1048576 and x2 = xindex // 8388608 (0 or 1 for
#     xindex < xnumel); x3 is the flat output index.
#   * Source element: in_ptr0[106954752 + x0 + 1048576*x2], where
#     106954752 == 58720256 + 23 * 2097152.  The middle factor of the output
#     index ((xindex // 1048576) % 8) does not appear in the source address, so
#     each 1048576-element source chunk is read for 8 consecutive output chunks.
#   * The loaded value is converted to float32 and stored to out_ptr0[x3]
#     (declared '*fp32' in the signature, input declared '*bf16').
#   * tmp1 = tmp0.to(tl.float32) re-casts a value that is already float32.
#   * xmask is built but unused; the load and store pass None as mask.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_131 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_131', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_131', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_131(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (106954752 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/w7/cw754m3bilrcajbdsehdv67vyabl4nygjibtcjdrogxkskhz3gvu.py
# Topologically Sorted Source Nodes: [cos, sin, view_73, key_states_73, k_24, chunk_49, setitem_48, mul_220, neg_49, cat_49, mul_221, k_embed_24, key_states_74], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_49 => cat_49
#   chunk_49 => split_49
#   cos => index
#   k_24 => convert_element_type_705
#   k_embed_24 => add_194
#   key_states_73 => permute_292
#   key_states_74 => convert_element_type_707
#   mul_220 => mul_268
#   mul_221 => mul_269
#   neg_49 => neg_73
#   setitem_48 => index_put_48, select_434, select_435, view_339
#   sin => index_1
#   view_73 => view_337
# Graph fragment:
#   %select_scatter_default_95 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_95]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_337 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_169, [1, 1, 2, 128]), kwargs = {})
#   %permute_292 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_337, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_705 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_292, torch.float32), kwargs = {})
#   %split_49 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_705, 64, -1), kwargs = {})
#   %select_434 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_95, 0, 0), kwargs = {})
#   %select_435 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_434, 0, 24), kwargs = {})
#   %mul_268 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_705, %index), kwargs = {})
#   %neg_73 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_99,), kwargs = {})
#   %cat_49 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_73, %getitem_98], -1), kwargs = {})
#   %mul_269 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_49, %index_1), kwargs = {})
#   %add_194 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_268, %mul_269), kwargs = {})
#   %convert_element_type_707 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_194, torch.bfloat16), kwargs = {})
#   %view_339 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_707, [2, 1, 128]), kwargs = {})
#   %index_put_48 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_435, [None, None, %arg1_1], %view_339), kwargs = {})
#   return %index_put_48
# NOTE(review): Registers the Triton source below with async_compile under the
# name 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_132'
# for device 'cuda'.
#
# Kernel summary (1-D pointwise, xnumel hard-coded to 2097152 inside the body):
#   * The body is a plain offset copy: out_ptr0[x0] = in_ptr0[50331648 + x0],
#     where 50331648 == 24 * 2097152.  The value passes through float32 on the
#     way (load is followed by .to(tl.float32)); both pointers are declared
#     '*bf16' in the signature.
#   * None of the arithmetic ops listed in the kernel name (mul, neg, cat, add,
#     index_put) appear in this body; presumably they are emitted by a separate
#     kernel that updates the copied buffer -- TODO confirm at the call site.
#   * xmask is built but unused; the load and store pass None as mask.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_132 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_132', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_132', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_132(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (50331648 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/f7/cf73gykqwx6nlho6mfvvlbnoli2scnmcazl3nyz5m7zm6gqcckif.py
# Topologically Sorted Source Nodes: [setitem_49, view_74, value_states_49], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_49 => index_put_49, select_439, select_440, view_340
#   value_states_49 => permute_293
#   view_74 => view_338
# Graph fragment:
#   %buf665 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf665]
#   %select_scatter_default_95 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_95]
#   %select_int_48 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_95, 0, 0), kwargs = {})
#   %select_scatter_default_96 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_48, %index_put_48, 0, 24), kwargs = {})
#   %select_scatter_default_97 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_95, %select_scatter_default_96, 0, 0), kwargs = {})
#   %select_439 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_97, 0, 1), kwargs = {})
#   %select_440 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_439, 0, 24), kwargs = {})
#   %view_338 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_170, [1, 1, 2, 128]), kwargs = {})
#   %permute_293 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_338, [0, 2, 1, 3]), kwargs = {})
#   %view_340 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_293, [2, 1, 128]), kwargs = {})
#   %index_put_49 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_440, [None, None, %arg1_1], %view_340), kwargs = {})
#   return %index_put_49
# NOTE(review): Registers the Triton source below with async_compile under the
# name 'triton_poi_fused_index_put_select_transpose_view_133' for device 'cuda'.
#
# Kernel summary (1-D pointwise, xnumel hard-coded to 2097152 inside the body,
# overriding the runtime argument):
#   * Three loads are issued per element: in_ptr0[x0], in_ptr1[50331648 + x0]
#     and in_ptr1[109051904 + x0], each converted to float32.
#     50331648 == 24 * 2097152 and 109051904 == 58720256 + 24 * 2097152.
#   * Both tl.where predicates are compile-time constants: tmp2 is (1 == 0),
#     i.e. always False, and tmp4 is (24 == 24), i.e. always True.  Hence
#     tmp9 is always tmp8, so the value stored to out_ptr0[x0] is
#     in_ptr1[109051904 + x0]; tmp5/tmp6 never reach the output.
#   * xmask is built but unused; every load and the store pass None as mask.
#     Presumably the launch grid covers exactly xnumel elements -- TODO confirm
#     against the grid computed by the pointwise heuristic.
triton_poi_fused_index_put_select_transpose_view_133 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_133', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_133', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_133(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (50331648 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (109051904 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 24, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/vv/cvvi2tleyoikqwjenbvml4333swjgor2vvac22uadqrqx65w6yns.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf668 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf668]
#   %buf665 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf665]
#   %select_scatter_default_95 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_95]
#   %select_int_48 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_95, 0, 0), kwargs = {})
#   %select_scatter_default_96 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_48, %index_put_48, 0, 24), kwargs = {})
#   %select_scatter_default_97 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_95, %select_scatter_default_96, 0, 0), kwargs = {})
#   %select_int_49 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_97, 0, 1), kwargs = {})
#   %select_scatter_default_98 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_49, %index_put_49, 0, 24), kwargs = {})
#   %select_scatter_default_99 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_97, %select_scatter_default_98, 0, 1), kwargs = {})
#   return %select_scatter_default_99
# Pointwise Triton kernel (registered through async_compile.triton) that
# writes a copy of in_ptr2 into out_ptr0 with two 2097152-element slices
# replaced by the contents of in_ptr0 / in_ptr1.
#
# The flat index covers 117440512 == 2 * 28 * 2097152 elements and is split as
#   x2 = xindex // 58720256        -> 0 or 1
#   x1 = (xindex // 2097152) % 28  -> 0..27
#   x0 = xindex % 2097152          -> offset inside one slice
# The chain of tl.where selects reduces to:
#   x2 == 0 and x1 == 24 -> in_ptr1[x0]
#   x2 == 1 and x1 == 24 -> in_ptr0[x0]
#   otherwise            -> in_ptr2[xindex]
# tmp8 compares the constants 1 and 0, so it is always False and tmp13 is
# always tmp12 (the in_ptr2 load at 58720256 + x3).
# The xnumel argument is overwritten with a constant inside the kernel, and
# xmask is built but every load/store is issued with mask None.
# Values are widened to fp32 on load; out_ptr0 is declared '*bf16' in the
# signature.
# NOTE(review): going by the placeholder order in the graph fragment above,
# in_ptr0/in_ptr1/in_ptr2 presumably bind to buf668/buf665/
# select_scatter_default_95 -- the call site is not visible here; confirm.
triton_poi_fused_134 = async_compile.triton('triton_poi_fused_134', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_134', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_134(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 24, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/jv/cjvu6m7uhgfgmtyzu25z46sx6styr2jy27qzic4iiboxwdsqujup.py
# Topologically Sorted Source Nodes: [attn_output_96], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_96 => clone_48, convert_element_type_709, expand_144, mul_271, permute_294, select_444, select_445, unsqueeze_48, view_341
# Graph fragment:
#   %select_scatter_default_99 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_99]
#   %select_444 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_99, 0, 0), kwargs = {})
#   %select_445 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_444, 0, 24), kwargs = {})
#   %convert_element_type_709 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_445, torch.float32), kwargs = {})
#   %unsqueeze_48 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_709, 2), kwargs = {})
#   %expand_144 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_48, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_48 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_144,), kwargs = {memory_format: torch.contiguous_format})
#   %view_341 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_48, [1, 16, 8192, 128]), kwargs = {})
#   %permute_294 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_341, [0, 1, 3, 2]), kwargs = {})
#   %mul_271 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_294, 0.29730177875068026), kwargs = {})
#   return %expand_147
# 2-D pointwise Triton kernel: gathers bf16 values from in_ptr0, multiplies
# them by a constant and stores the fp32 result in a transposed layout.
#
# Index split: y covers 2048 values with y0 = yindex % 128 and
# y1 = yindex // 128 (0..15); x covers 8192 values (x2).
# Source address: in_ptr0[50331648 + y0 + 128*x2 + 1048576*(y1 // 8)], where
# 50331648 == 24 * 2097152. Because of the y1 // 8 term, eight consecutive y1
# values read the same 1048576-element block, so each of the two source blocks
# is repeated 8 times in the output.
# Destination address: out_ptr0[x2 + 8192*yindex] -- x is the unit-stride axis
# of the output while it has stride 128 in the input.
# The scale 0.29730177875068026 is numerically about 128 ** -0.25.
# tmp1 re-casts a value that the load already converted to fp32.
# ynumel/xnumel arguments are overwritten with constants and the masks are
# built but unused (loads/stores pass mask None).
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_135 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_135', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_135', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_135(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (50331648 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/mp/cmpx355bgsxidb7utrov422sw7ri54wigmmkww22a73rv7vpxail.py
# Topologically Sorted Source Nodes: [setitem_49, attn_output_96], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_96 => clone_49, convert_element_type_710, expand_145, unsqueeze_49
#   setitem_49 => select_442, select_443
# Graph fragment:
#   %select_scatter_default_99 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_99]
#   %select_442 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_99, 0, 1), kwargs = {})
#   %select_443 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_442, 0, 24), kwargs = {})
#   %convert_element_type_710 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_443, torch.float32), kwargs = {})
#   %unsqueeze_49 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_710, 2), kwargs = {})
#   %expand_145 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_49, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_49 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_145,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_49
# Pointwise Triton kernel that expands a bf16 region of in_ptr0 into a
# contiguous fp32 buffer of 16777216 == 2 * 8 * 1048576 elements.
#
# Index split: x0 = xindex % 1048576 and x2 = xindex // 8388608 (0 or 1).
# Source address: in_ptr0[109051904 + x0 + 1048576*x2], where
# 109051904 == 58720256 + 24 * 2097152.
# The factor (xindex // 1048576) % 8 never enters the source address, so each
# of the two 1048576-element source blocks is written 8 times in a row.
# tmp1 re-casts a value that the load already converted to fp32.
# The xnumel argument is overwritten with a constant, and xmask is built but
# unused (the load and store pass mask None).
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_136 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_136', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_136', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_136(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (109051904 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/f7/cf7yjyrahhnajl4ylhtoalbu64ekqz4coo3twb2a44wlbyx2y22v.py
# Topologically Sorted Source Nodes: [cos, sin, view_76, key_states_76, k_25, chunk_51, setitem_50, mul_229, neg_51, cat_51, mul_230, k_embed_25, key_states_77], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_51 => cat_51
#   chunk_51 => split_51
#   cos => index
#   k_25 => convert_element_type_734
#   k_embed_25 => add_202
#   key_states_76 => permute_304
#   key_states_77 => convert_element_type_736
#   mul_229 => mul_279
#   mul_230 => mul_280
#   neg_51 => neg_76
#   setitem_50 => index_put_50, select_452, select_453, view_353
#   sin => index_1
#   view_76 => view_351
# Graph fragment:
#   %select_scatter_default_99 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_99]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_351 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_176, [1, 1, 2, 128]), kwargs = {})
#   %permute_304 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_351, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_734 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_304, torch.float32), kwargs = {})
#   %split_51 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_734, 64, -1), kwargs = {})
#   %select_452 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_99, 0, 0), kwargs = {})
#   %select_453 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_452, 0, 25), kwargs = {})
#   %mul_279 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_734, %index), kwargs = {})
#   %neg_76 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_103,), kwargs = {})
#   %cat_51 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_76, %getitem_102], -1), kwargs = {})
#   %mul_280 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_51, %index_1), kwargs = {})
#   %add_202 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_279, %mul_280), kwargs = {})
#   %convert_element_type_736 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_202, torch.bfloat16), kwargs = {})
#   %view_353 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_736, [2, 1, 128]), kwargs = {})
#   %index_put_50 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_453, [None, None, %arg1_1], %view_353), kwargs = {})
#   return %index_put_50
# Pointwise Triton kernel that copies 2097152 contiguous elements:
#   out_ptr0[i] = in_ptr0[52428800 + i]
# where 52428800 == 25 * 2097152. The value is widened to fp32 on load;
# out_ptr0 is declared '*bf16' in the signature.
# The xnumel argument is overwritten with a constant, and xmask is built but
# unused (the load and store pass mask None).
# NOTE(review): the kernel name and the header above list mul/neg/cat/
# index_put nodes, but only this copy appears in the kernel body; those ops
# are presumably applied by other kernels or calls not visible here -- confirm
# against the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_137 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_137', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_137', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_137(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (52428800 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/6t/c6tg47yn3c4pmtxdwjlcexj7uhurvx34kchmvkavfn55zsiaogii.py
# Topologically Sorted Source Nodes: [setitem_51, view_77, value_states_51], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_51 => index_put_51, select_457, select_458, view_354
#   value_states_51 => permute_305
#   view_77 => view_352
# Graph fragment:
#   %buf692 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf692]
#   %select_scatter_default_99 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_99]
#   %select_int_50 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_99, 0, 0), kwargs = {})
#   %select_scatter_default_100 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_50, %index_put_50, 0, 25), kwargs = {})
#   %select_scatter_default_101 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_99, %select_scatter_default_100, 0, 0), kwargs = {})
#   %select_457 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_101, 0, 1), kwargs = {})
#   %select_458 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_457, 0, 25), kwargs = {})
#   %view_352 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_177, [1, 1, 2, 128]), kwargs = {})
#   %permute_305 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_352, [0, 2, 1, 3]), kwargs = {})
#   %view_354 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_305, [2, 1, 128]), kwargs = {})
#   %index_put_51 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_458, [None, None, %arg1_1], %view_354), kwargs = {})
#   return %index_put_51
# Pointwise Triton kernel over 2097152 elements whose net effect is
#   out_ptr0[i] = in_ptr1[111149056 + i]
# where 111149056 == 58720256 + 25 * 2097152.
#
# Both select conditions are comparisons between constants:
#   tmp2 is (1 == 0), always False; tmp4 is (25 == 25), always True.
# Hence tmp7 is always tmp5 and tmp9 is always tmp8, so the loads from in_ptr0
# and from in_ptr1 + 52428800 do not influence the stored value.
# The xnumel argument is overwritten with a constant, and xmask is built but
# unused (loads and the store pass mask None).
# NOTE(review): the header above also lists an index_put node; no scatter is
# performed in this kernel body, so it is presumably done elsewhere -- confirm
# against the call site.
triton_poi_fused_index_put_select_transpose_view_138 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_138', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_138', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_138(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (52428800 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (111149056 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 25, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/kg/ckgdhqy46xomboumn7w2siadpheyo32vjagufegvvqfy5pngso2z.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf695 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf695]
#   %buf692 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf692]
#   %select_scatter_default_99 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_99]
#   %select_int_50 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_99, 0, 0), kwargs = {})
#   %select_scatter_default_100 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_50, %index_put_50, 0, 25), kwargs = {})
#   %select_scatter_default_101 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_99, %select_scatter_default_100, 0, 0), kwargs = {})
#   %select_int_51 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_101, 0, 1), kwargs = {})
#   %select_scatter_default_102 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_51, %index_put_51, 0, 25), kwargs = {})
#   %select_scatter_default_103 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_101, %select_scatter_default_102, 0, 1), kwargs = {})
#   return %select_scatter_default_103
# Pointwise Triton kernel (registered through async_compile.triton) that
# writes a copy of in_ptr2 into out_ptr0 with two 2097152-element slices
# replaced by the contents of in_ptr0 / in_ptr1.
#
# The flat index covers 117440512 == 2 * 28 * 2097152 elements and is split as
#   x2 = xindex // 58720256        -> 0 or 1
#   x1 = (xindex // 2097152) % 28  -> 0..27
#   x0 = xindex % 2097152          -> offset inside one slice
# The chain of tl.where selects reduces to:
#   x2 == 0 and x1 == 25 -> in_ptr1[x0]
#   x2 == 1 and x1 == 25 -> in_ptr0[x0]
#   otherwise            -> in_ptr2[xindex]
# tmp8 compares the constants 1 and 0, so it is always False and tmp13 is
# always tmp12 (the in_ptr2 load at 58720256 + x3).
# The xnumel argument is overwritten with a constant inside the kernel, and
# xmask is built but every load/store is issued with mask None.
# Values are widened to fp32 on load; out_ptr0 is declared '*bf16' in the
# signature.
# NOTE(review): going by the placeholder order in the graph fragment above,
# in_ptr0/in_ptr1/in_ptr2 presumably bind to buf695/buf692/
# select_scatter_default_99 -- the call site is not visible here; confirm.
triton_poi_fused_139 = async_compile.triton('triton_poi_fused_139', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_139', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_139(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 25, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/t6/ct6u27i4vxnizhbimlcd2fdrchtx4mmcwtp2nylxmphte42rfnfi.py
# Topologically Sorted Source Nodes: [attn_output_100], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_100 => clone_50, convert_element_type_738, expand_150, mul_282, permute_306, select_462, select_463, unsqueeze_50, view_355
# Graph fragment:
#   %select_scatter_default_103 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_103]
#   %select_462 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_103, 0, 0), kwargs = {})
#   %select_463 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_462, 0, 25), kwargs = {})
#   %convert_element_type_738 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_463, torch.float32), kwargs = {})
#   %unsqueeze_50 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_738, 2), kwargs = {})
#   %expand_150 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_50, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_50 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_150,), kwargs = {memory_format: torch.contiguous_format})
#   %view_355 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_50, [1, 16, 8192, 128]), kwargs = {})
#   %permute_306 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_355, [0, 1, 3, 2]), kwargs = {})
#   %mul_282 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_306, 0.29730177875068026), kwargs = {})
#   return %expand_153
# Summary of the kernel source registered below (the string literal itself is left untouched):
#   * 2D pointwise kernel; ynumel/xnumel are overwritten with the constants 2048 and 8192.
#   * yindex is split into y0 = yindex % 128 and y1 = yindex // 128.
#   * Each element is read from the bf16 input at
#       52428800 + y0 + 128*x2 + 1048576*(y1 // 8)
#     so eight consecutive y1 values read the same 1048576-element input slab
#     (this matches the expand to 8 in the graph fragment above).
#   * 52428800 == 25 * 2097152, consistent with the select(0, 0) / select(0, 25) in the graph fragment.
#   * The value is converted to fp32, multiplied by 0.29730177875068026 and stored at x2 + 8192*y3,
#     i.e. with x as the fastest-varying output dimension.
#   * All loads/stores use mask=None; ymask/xmask are built but not used.
# NOTE(review): 0.29730177875068026 is numerically 128 ** -0.25; presumably an attention
# scaling factor split between the two matmul operands -- confirm against the model code.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_140 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_140', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_140', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_140(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (52428800 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/3p/c3p4tcqopxwlilzxwlw5wpruxe2qvox6wqyylcfodigc4cmidmyu.py
# Topologically Sorted Source Nodes: [setitem_51, attn_output_100], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_100 => clone_51, convert_element_type_739, expand_151, unsqueeze_51
#   setitem_51 => select_460, select_461
# Graph fragment:
#   %select_scatter_default_103 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_103]
#   %select_460 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_103, 0, 1), kwargs = {})
#   %select_461 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_460, 0, 25), kwargs = {})
#   %convert_element_type_739 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_461, torch.float32), kwargs = {})
#   %unsqueeze_51 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_739, 2), kwargs = {})
#   %expand_151 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_51, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_51 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_151,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_51
# Summary of the kernel source registered below (the string literal itself is left untouched):
#   * 1D pointwise kernel over xnumel = 16777216 output elements (the argument is overwritten
#     with that constant).
#   * x0 = xindex % 1048576 and x2 = xindex // 8388608, so each 8388608-element output group
#     re-reads the same 1048576-element input slab eight times (the expand to 8 in the graph
#     fragment above, materialised as a contiguous clone).
#   * Input address: 111149056 + x0 + 1048576*x2, where 111149056 == 58720256 + 25 * 2097152,
#     consistent with select(0, 1) followed by select(0, 25) in the graph fragment.
#   * The bf16 value is converted to fp32 and stored at the linear index x3; no scaling is applied.
#   * Loads/stores use mask=None; xmask is built but not used.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_141 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_141', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_141', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_141(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (111149056 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/fi/cfittugo2ivwqaejprddxl3hj7jqx7nhoezcvepnvkrhbxrh2zr3.py
# Topologically Sorted Source Nodes: [cos, sin, view_79, key_states_79, k_26, chunk_53, setitem_52, mul_238, neg_53, cat_53, mul_239, k_embed_26, key_states_80], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_53 => cat_53
#   chunk_53 => split_53
#   cos => index
#   k_26 => convert_element_type_763
#   k_embed_26 => add_210
#   key_states_79 => permute_316
#   key_states_80 => convert_element_type_765
#   mul_238 => mul_290
#   mul_239 => mul_291
#   neg_53 => neg_79
#   setitem_52 => index_put_52, select_470, select_471, view_367
#   sin => index_1
#   view_79 => view_365
# Graph fragment:
#   %select_scatter_default_103 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_103]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_365 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_183, [1, 1, 2, 128]), kwargs = {})
#   %permute_316 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_365, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_763 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_316, torch.float32), kwargs = {})
#   %split_53 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_763, 64, -1), kwargs = {})
#   %select_470 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_103, 0, 0), kwargs = {})
#   %select_471 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_470, 0, 26), kwargs = {})
#   %mul_290 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_763, %index), kwargs = {})
#   %neg_79 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_107,), kwargs = {})
#   %cat_53 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_79, %getitem_106], -1), kwargs = {})
#   %mul_291 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_53, %index_1), kwargs = {})
#   %add_210 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_290, %mul_291), kwargs = {})
#   %convert_element_type_765 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_210, torch.bfloat16), kwargs = {})
#   %view_367 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_765, [2, 1, 128]), kwargs = {})
#   %index_put_52 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_471, [None, None, %arg1_1], %view_367), kwargs = {})
#   return %index_put_52
# Summary of the kernel source registered below (the string literal itself is left untouched):
#   * 1D pointwise kernel over xnumel = 2097152 elements (the argument is overwritten with
#     that constant).
#   * It copies in_ptr0[54525952 + x0] to out_ptr0[x0]; 54525952 == 26 * 2097152, consistent
#     with select(0, 0) followed by select(0, 26) in the graph fragment above.
#   * The value is widened to fp32 on load and written through a pointer declared '*bf16'
#     in the signature; no arithmetic is performed in between.
# NOTE(review): despite the long fused-op name, this body contains none of the mul/neg/cat/add
# math and no scatter -- it only copies the slab. Presumably the index_put write is applied to
# the copied buffer by a separate kernel or call outside this view -- confirm at the call site.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_142 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_142', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_142', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_142(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (54525952 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4s/c4s6m6anlxxcbamvosjzjaphq3iqar7emit74zjehlhfdkjo5bgp.py
# Topologically Sorted Source Nodes: [setitem_53, view_80, value_states_53], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_53 => index_put_53, select_475, select_476, view_368
#   value_states_53 => permute_317
#   view_80 => view_366
# Graph fragment:
#   %buf720 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf720]
#   %select_scatter_default_103 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_103]
#   %select_int_52 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_103, 0, 0), kwargs = {})
#   %select_scatter_default_104 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_52, %index_put_52, 0, 26), kwargs = {})
#   %select_scatter_default_105 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_103, %select_scatter_default_104, 0, 0), kwargs = {})
#   %select_475 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_105, 0, 1), kwargs = {})
#   %select_476 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_475, 0, 26), kwargs = {})
#   %view_366 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_184, [1, 1, 2, 128]), kwargs = {})
#   %permute_317 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_366, [0, 2, 1, 3]), kwargs = {})
#   %view_368 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_317, [2, 1, 128]), kwargs = {})
#   %index_put_53 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_476, [None, None, %arg1_1], %view_368), kwargs = {})
#   return %index_put_53
# Summary of the kernel source registered below (the string literal itself is left untouched):
#   * 1D pointwise kernel over xnumel = 2097152 elements (the argument is overwritten with
#     that constant).
#   * Three loads are issued: in_ptr0[x0], in_ptr1[54525952 + x0] and in_ptr1[113246208 + x0],
#     where 54525952 == 26 * 2097152 and 113246208 == 58720256 + 26 * 2097152.
#   * Both predicates are built from constants: tmp2 is (1 == 0) and tmp4 is (26 == 26).
#     Hence tmp7 always selects tmp5, and tmp9 always selects tmp8, so the value stored to
#     out_ptr0[x0] is always the one loaded from in_ptr1[113246208 + x0]. The other two loads
#     do not influence the result.
#   * Values are widened to fp32 on load and written through a pointer declared '*bf16'.
# NOTE(review): as in the preceding copy kernel, no scatter happens here; presumably the
# index_put itself is performed on this output elsewhere -- confirm at the call site.
triton_poi_fused_index_put_select_transpose_view_143 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_143', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_143', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_143(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (54525952 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (113246208 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 26, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/it/citq6uwhux3n2hwr7muuja3nvjw5vq3efqtkl5r4przmqawmmmxp.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf723 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf723]
#   %buf720 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf720]
#   %select_scatter_default_103 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_103]
#   %select_int_52 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_103, 0, 0), kwargs = {})
#   %select_scatter_default_104 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_52, %index_put_52, 0, 26), kwargs = {})
#   %select_scatter_default_105 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_103, %select_scatter_default_104, 0, 0), kwargs = {})
#   %select_int_53 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_105, 0, 1), kwargs = {})
#   %select_scatter_default_106 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_53, %index_put_53, 0, 26), kwargs = {})
#   %select_scatter_default_107 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_105, %select_scatter_default_106, 0, 1), kwargs = {})
#   return %select_scatter_default_107
# Summary of the kernel source registered below (the string literal itself is left untouched):
#   * 1D pointwise kernel over xnumel = 117440512 == 2 * 58720256 elements (the argument is
#     overwritten with that constant); it writes every element of out_ptr0.
#   * Index decomposition: x2 = xindex // 58720256 (which half), x1 = (xindex // 2097152) % 28
#     (which 2097152-element slab inside the half), x0 = offset inside the slab,
#     x3 = offset inside the half, x4 = linear index.
#   * tmp8 is the constant comparison (1 == 0), so tmp13 always equals tmp12, i.e.
#     in_ptr2[58720256 + x3].
#   * Net selection performed by the tl.where chain:
#       - x2 == 1: slab 26 comes from in_ptr0[x0]; every other slab from in_ptr2[58720256 + x3].
#       - x2 == 0: slab 26 comes from in_ptr1[x0]; every other slab from in_ptr2[x3].
#       - otherwise the fallback is in_ptr2[x4].
#     This matches the nested select_scatter chain (index 26 inside halves 0 and 1) listed in
#     the graph fragment above.
#   * Values are widened to fp32 on load and written through a pointer declared '*bf16'.
# NOTE(review): all five loads are issued for every element regardless of which branch is
# selected -- inherent to this generated form, not something to hand-edit here.
triton_poi_fused_144 = async_compile.triton('triton_poi_fused_144', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_144', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 939524096}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_144(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x3 = (xindex % 58720256)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 26, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/aw/caw2hym6lyia4no5mtgjapogt4jwkic2orrdm3nj44qotm6nqsis.py
# Topologically Sorted Source Nodes: [attn_output_104], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_104 => clone_52, convert_element_type_767, expand_156, mul_293, permute_318, select_480, select_481, unsqueeze_52, view_369
# Graph fragment:
#   %select_scatter_default_107 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_107]
#   %select_480 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_107, 0, 0), kwargs = {})
#   %select_481 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_480, 0, 26), kwargs = {})
#   %convert_element_type_767 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_481, torch.float32), kwargs = {})
#   %unsqueeze_52 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_767, 2), kwargs = {})
#   %expand_156 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_52, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_52 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_156,), kwargs = {memory_format: torch.contiguous_format})
#   %view_369 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_52, [1, 16, 8192, 128]), kwargs = {})
#   %permute_318 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_369, [0, 1, 3, 2]), kwargs = {})
#   %mul_293 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_318, 0.29730177875068026), kwargs = {})
#   return %expand_159
# Summary of the kernel source registered below (the string literal itself is left untouched):
#   * 2D pointwise kernel; ynumel/xnumel are overwritten with the constants 2048 and 8192.
#   * yindex is split into y0 = yindex % 128 and y1 = yindex // 128.
#   * Each element is read from the bf16 input at
#       54525952 + y0 + 128*x2 + 1048576*(y1 // 8)
#     so eight consecutive y1 values read the same 1048576-element input slab
#     (this matches the expand to 8 in the graph fragment above).
#   * 54525952 == 26 * 2097152, consistent with the select(0, 0) / select(0, 26) in the graph fragment.
#   * The value is converted to fp32, multiplied by 0.29730177875068026 and stored at x2 + 8192*y3,
#     i.e. with x as the fastest-varying output dimension.
#   * All loads/stores use mask=None; ymask/xmask are built but not used.
# NOTE(review): 0.29730177875068026 is numerically 128 ** -0.25; presumably an attention
# scaling factor split between the two matmul operands -- confirm against the model code.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_145 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_145', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_145', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_145(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (54525952 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/rm/crmeikelqfbraiaookp4psddhfxsfpbn2vc64vantc6tz4u6twyu.py
# Topologically Sorted Source Nodes: [setitem_53, attn_output_104], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_104 => clone_53, convert_element_type_768, expand_157, unsqueeze_53
#   setitem_53 => select_478, select_479
# Graph fragment:
#   %select_scatter_default_107 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_107]
#   %select_478 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_107, 0, 1), kwargs = {})
#   %select_479 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_478, 0, 26), kwargs = {})
#   %convert_element_type_768 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_479, torch.float32), kwargs = {})
#   %unsqueeze_53 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_768, 2), kwargs = {})
#   %expand_157 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_53, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_53 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_157,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_53
# Hands the Triton source below to async_compile.triton for device 'cuda'.
# The kernel source is a string literal, so these notes are kept outside it.
#
# What the kernel body does, per output element xindex:
#   x0 = xindex % 1048576
#   x2 = xindex // 8388608
#   out_ptr0[xindex] = float32(in_ptr0[113246208 + x0 + 1048576*x2])
# - The xnumel argument is overwritten with the constant 16777216.
# - The source offset does not depend on (xindex // 1048576) % 8, so each
#   1048576-element slab of the bf16 input is written 8 times to the output.
# - 113246208 == 1*58720256 + 26*2097152, i.e. the strides of the two
#   select(...) ops listed in the graph fragment above (indices 1 and 26).
# - xmask is built but unused; load and store are called with mask=None.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_146 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_146', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_146', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_146(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (113246208 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ry/cry7eojc5ufrevyfqhmicbxrt7xreecojszfom2svz5tk26binbs.py
# Topologically Sorted Source Nodes: [cos, sin, view_82, key_states_82, k_27, chunk_55, setitem_54, mul_247, neg_55, cat_55, mul_248, k_embed_27, key_states_83], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
# Source node to ATen node mapping:
#   cat_55 => cat_55
#   chunk_55 => split_55
#   cos => index
#   k_27 => convert_element_type_792
#   k_embed_27 => add_218
#   key_states_82 => permute_328
#   key_states_83 => convert_element_type_794
#   mul_247 => mul_301
#   mul_248 => mul_302
#   neg_55 => neg_82
#   setitem_54 => index_put_54, select_488, select_489, view_381
#   sin => index_1
#   view_82 => view_379
# Graph fragment:
#   %select_scatter_default_107 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_107]
#   %index : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg0_1, [%arg1_1]), kwargs = {})
#   %index_1 : Tensor "bf16[1, 128][128, 1]cuda:0"[num_users=56] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg2_1, [%arg1_1]), kwargs = {})
#   %view_379 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_190, [1, 1, 2, 128]), kwargs = {})
#   %permute_328 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_379, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_792 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_328, torch.float32), kwargs = {})
#   %split_55 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_792, 64, -1), kwargs = {})
#   %select_488 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_107, 0, 0), kwargs = {})
#   %select_489 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_488, 0, 27), kwargs = {})
#   %mul_301 : Tensor "f32[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_792, %index), kwargs = {})
#   %neg_82 : Tensor "f32[1, 2, 1, 64][128, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_111,), kwargs = {})
#   %cat_55 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_82, %getitem_110], -1), kwargs = {})
#   %mul_302 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_55, %index_1), kwargs = {})
#   %add_218 : Tensor "f32[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_301, %mul_302), kwargs = {})
#   %convert_element_type_794 : Tensor "bf16[1, 2, 1, 128][256, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_218, torch.bfloat16), kwargs = {})
#   %view_381 : Tensor "bf16[2, 1, 128][128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%convert_element_type_794, [2, 1, 128]), kwargs = {})
#   %index_put_54 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_489, [None, None, %arg1_1], %view_381), kwargs = {})
#   return %index_put_54
# Hands the Triton source below to async_compile.triton for device 'cuda'.
# The kernel source is a string literal, so these notes are kept outside it.
#
# Despite the long fused name, the kernel body is a plain offset copy:
#   out_ptr0[xindex] = in_ptr0[56623104 + xindex]
# - The xnumel argument is overwritten with the constant 2097152.
# - The loaded value is converted with .to(tl.float32) before the store; both
#   pointers are declared '*bf16' in the signature.
# - 56623104 == 27*2097152, i.e. index 27 along the stride-2097152 dimension
#   (select_488 / select_489 in the graph fragment above use indices 0 and 27).
# - xmask is built but unused; load and store are called with mask=None.
# NOTE(review): the index_put itself does not appear in this kernel body —
# presumably it is applied to the copied buffer by a later step; confirm in
# the call function.
triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_147 = async_compile.triton('triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_147', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_147', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_147(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (56623104 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/5a/c5a464a4r3lg5nro6wo2ljb4ktaf5qcwzwltifynsuffickdgqlx.py
# Topologically Sorted Source Nodes: [setitem_55, view_83, value_states_55], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_55 => index_put_55, select_493, select_494, view_382
#   value_states_55 => permute_329
#   view_83 => view_380
# Graph fragment:
#   %buf747 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf747]
#   %select_scatter_default_107 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_107]
#   %select_int_54 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_107, 0, 0), kwargs = {})
#   %select_scatter_default_108 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_54, %index_put_54, 0, 27), kwargs = {})
#   %select_scatter_default_109 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_107, %select_scatter_default_108, 0, 0), kwargs = {})
#   %select_493 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_109, 0, 1), kwargs = {})
#   %select_494 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_493, 0, 27), kwargs = {})
#   %view_380 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_191, [1, 1, 2, 128]), kwargs = {})
#   %permute_329 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_380, [0, 2, 1, 3]), kwargs = {})
#   %view_382 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_329, [2, 1, 128]), kwargs = {})
#   %index_put_55 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_494, [None, None, %arg1_1], %view_382), kwargs = {})
#   return %index_put_55
# Hands the Triton source below to async_compile.triton for device 'cuda'.
# The kernel source is a string literal, so these notes are kept outside it.
#
# What the kernel body does, per element xindex (xnumel is overwritten with
# the constant 2097152):
#   tmp5 = in_ptr0[xindex]
#   tmp6 = in_ptr1[56623104 + xindex]      (56623104 == 27*2097152)
#   tmp8 = in_ptr1[115343360 + xindex]     (115343360 == 58720256 + 27*2097152)
#   tmp7 = where(27 == 27, tmp5, tmp6)
#   tmp9 = where(1 == 0,  tmp7, tmp8)
#   out_ptr0[xindex] = tmp9
# Both tl.where conditions compare constants (tmp2 is 1 == 0, tmp4 is
# 27 == 27), so as written tmp7 always takes tmp5 and tmp9 always takes tmp8:
# the stored value is the in_ptr1 element at offset 115343360 + xindex, and
# the tmp5/tmp6 loads only feed the branch that is never selected.
# xmask is built but unused; loads and the store are called with mask=None.
triton_poi_fused_index_put_select_transpose_view_148 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_148', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_148', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_148(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (56623104 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (115343360 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 27, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/nf/cnfog4bnvcvjs3mveqb2vs7ykd6mi7ksf7slj3bh44rmbit4ga2o.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.copy_]
# Source node to ATen node mapping:
# Graph fragment:
#   %buf750 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf750]
#   %buf747 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf747]
#   %select_scatter_default_107 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_107]
#   %select_scatter_default_111 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_111]
#   %copy_ : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=copy_]
#   %buf5 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf5]
#   %buf8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf8]
#   %select_scatter_default_3 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_int_54 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_107, 0, 0), kwargs = {})
#   %select_scatter_default_108 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_54, %index_put_54, 0, 27), kwargs = {})
#   %select_scatter_default_109 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_107, %select_scatter_default_108, 0, 0), kwargs = {})
#   %select_int_55 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_109, 0, 1), kwargs = {})
#   %select_scatter_default_110 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_55, %index_put_55, 0, 27), kwargs = {})
#   %select_scatter_default_111 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_109, %select_scatter_default_110, 0, 1), kwargs = {})
#   %copy_ : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %select_scatter_default_111), kwargs = {})
#   return %select_scatter_default_111,%buf775
# Hands the Triton source below to async_compile.triton for device 'cuda'.
# The kernel source is a string literal, so these notes are kept outside it.
#
# Index decomposition of the flat xindex (xnumel is overwritten with the
# constant 117440512 == 2*58720256):
#   x2 = xindex // 58720256            outermost index
#   x1 = (xindex // 2097152) % 28      index along the 28-wide dimension
#   x0 = xindex % 2097152              offset inside one 2097152-element slab
#   x4 = xindex % 58720256             offset inside one x2 half
#   x3 = xindex
# Value selection as written in the body:
#   tmp11 = where(x1 == 27, in_ptr1[x0], in_ptr2[x4])
#   tmp13 = where(1 == 0,  tmp11,        in_ptr2[58720256 + x4])
#           (constant-false condition, so tmp13 is always the in_ptr2 load)
#   tmp14 = where(x1 == 27, in_ptr0[x0], tmp13)
#   tmp17 = where(x2 == 0,  tmp11,       in_ptr2[x3])
#   tmp18 = where(x2 == 1,  tmp14,       tmp17)
# So slab x1 == 27 comes from in_ptr1 when x2 == 0 and from in_ptr0 when
# x2 == 1; every other element is read from in_ptr2.
# tmp18 is stored to both out_ptr0[x3] and out_ptr1[x3]; inductor_meta lists
# out_ptr1 under mutated_arg_names.
# xmask is built but unused; loads and stores are called with mask=None.
triton_poi_fused_copy__149 = async_compile.triton('triton_poi_fused_copy__149', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 134217728}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_copy__149', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 2, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 1644167168}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_copy__149(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 117440512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 58720256
    x1 = ((xindex // 2097152) % 28)
    x0 = (xindex % 2097152)
    x4 = (xindex % 58720256)
    x3 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (58720256 + x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x3), None).to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 27, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x3), tmp18, None)
    tl.store(out_ptr1 + (x3), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/vk/cvk6mfi55kolkbzhuo4lk55aqtg2giip4je2paskxjjhs3nb35wd.py
# Topologically Sorted Source Nodes: [attn_output_108], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_108 => clone_54, convert_element_type_796, expand_162, mul_304, permute_330, select_498, select_499, unsqueeze_54, view_383
# Graph fragment:
#   %select_scatter_default_111 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_111]
#   %select_498 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_111, 0, 0), kwargs = {})
#   %select_499 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_498, 0, 27), kwargs = {})
#   %convert_element_type_796 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_499, torch.float32), kwargs = {})
#   %unsqueeze_54 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_796, 2), kwargs = {})
#   %expand_162 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_54, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_54 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_162,), kwargs = {memory_format: torch.contiguous_format})
#   %view_383 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_54, [1, 16, 8192, 128]), kwargs = {})
#   %permute_330 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_383, [0, 1, 3, 2]), kwargs = {})
#   %mul_304 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_330, 0.29730177875068026), kwargs = {})
#   return %expand_165
# Hands the Triton source below to async_compile.triton for device 'cuda'.
# The kernel source is a string literal, so these notes are kept outside it.
#
# Two-dimensional pointwise kernel: program_id(1) tiles y, program_id(0)
# tiles x. ynumel/xnumel are overwritten with the constants 2048 and 8192.
#   y0 = yindex % 128
#   y1 = yindex // 128
#   src = in_ptr0[56623104 + y0 + 128*x + 1048576*(y1 // 8)]
#   out_ptr0[x + 8192*yindex] = float32(src) * 0.29730177875068026
# - The source is read with stride 1 along y0 and stride 128 along x, while
#   the destination is written with stride 1 along x and stride 8192 along y.
# - (y1 // 8) means eight consecutive y1 values read the same source slab.
# - 56623104 == 27*2097152, i.e. index 27 along the stride-2097152 dimension
#   (select_498 / select_499 in the graph fragment above use indices 0 and 27).
# - The scale factor is the literal from mul_304 in the graph fragment above.
# - ymask/xmask are built but unused; load and store are called with mask=None.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_150 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_150', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_150', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_150(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (56623104 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qc/cqcxte55kcbiiu5z35rwu2igopys7jn56jbl2nmq7xppzce76zev.py
# Topologically Sorted Source Nodes: [setitem_55, attn_output_108], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_108 => clone_55, convert_element_type_797, expand_163, unsqueeze_55
#   setitem_55 => select_496, select_497
# Graph fragment:
#   %select_scatter_default_111 : Tensor "bf16[2, 28, 1, 2, 8192, 128][58720256, 2097152, 117440512, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_111]
#   %select_496 : Tensor "bf16[28, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_111, 0, 1), kwargs = {})
#   %select_497 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_496, 0, 27), kwargs = {})
#   %convert_element_type_797 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_497, torch.float32), kwargs = {})
#   %unsqueeze_55 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_797, 2), kwargs = {})
#   %expand_163 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_55, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_55 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_163,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_55
# Hands the Triton source below to async_compile.triton for device 'cuda'.
# The kernel source is a string literal, so these notes are kept outside it.
#
# What the kernel body does, per output element xindex:
#   x0 = xindex % 1048576
#   x2 = xindex // 8388608
#   out_ptr0[xindex] = float32(in_ptr0[115343360 + x0 + 1048576*x2])
# - The xnumel argument is overwritten with the constant 16777216.
# - The source offset does not depend on (xindex // 1048576) % 8, so each
#   1048576-element slab of the bf16 input is written 8 times to the output.
# - 115343360 == 1*58720256 + 27*2097152, i.e. the strides of the two
#   select(...) ops listed in the graph fragment above (indices 1 and 27).
# - xmask is built but unused; load and store are called with mask=None.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_151 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_151', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_151', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_151(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (115343360 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/rf/crfzd4khjxy74embd27g3gtpv3vgmz6ekmt7ktbvfpr43gdpafhj.py
# Topologically Sorted Source Nodes: [hidden_states_131, hidden_states_134, hidden_states_136, hidden_states_139, to_224, pow_57, variance_56, add_168, rsqrt_56, mul_252, hidden_56, hidden_states_140], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_168 => add_224
#   hidden_56 => convert_element_type_813
#   hidden_states_131 => add_212
#   hidden_states_134 => add_215
#   hidden_states_136 => add_220
#   hidden_states_139 => add_223
#   hidden_states_140 => mul_309
#   mul_252 => mul_308
#   pow_57 => pow_57
#   rsqrt_56 => rsqrt_56
#   to_224 => convert_element_type_812
#   variance_56 => mean_56
# Graph fragment:
#   %add_207 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_207]
#   %mm_185 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_185]
#   %mm_188 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_188]
#   %mm_192 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_192]
#   %mm_195 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_195]
#   %add_223 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_223]
#   %buf770 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf770]
#   %arg257_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg257_1]
#   %add_212 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_207, %mm_185), kwargs = {})
#   %add_215 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_212, %mm_188), kwargs = {})
#   %add_220 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_215, %mm_192), kwargs = {})
#   %add_223 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_220, %mm_195), kwargs = {})
#   %convert_element_type_812 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_223, torch.float32), kwargs = {})
#   %pow_57 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_812, 2), kwargs = {})
#   %mean_56 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_57, [-1], True), kwargs = {})
#   %add_224 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_56, 1e-05), kwargs = {})
#   %rsqrt_56 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_224,), kwargs = {})
#   %mul_308 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_223, %rsqrt_56), kwargs = {})
#   %convert_element_type_813 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_308, torch.bfloat16), kwargs = {})
#   %mul_309 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_813, %arg257_1), kwargs = {})
#   return %add_223,%buf770,%mul_309
# Registers the Triton source below with `async_compile.triton` under the name
# `triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_152`, targeting CUDA
# (`device_str='cuda'`).
#
# What the embedded kernel does (one row; `xnumel` and `r0_numel` are overwritten
# with the constants 1 and 2048 at the top of the kernel body):
#   * Loop 1: loads `in_out_ptr0` and `in_ptr0`..`in_ptr3`, casts each to float32,
#     adds the five values (tmp8), stores tmp8 back to `in_out_ptr0`, and
#     accumulates tmp8 * tmp8 per lane into `_tmp12` (guarded by `r0_mask`).
#   * Between the loops: `tl.sum(_tmp12, 1)` reduces the accumulator to `tmp12`.
#   * Loop 2: reloads `in_out_ptr0`, multiplies it by
#     rsqrt(tmp12 / 2048.0 + 1e-05), multiplies by `in_ptr4`, and stores the
#     result to `in_out_ptr0`.
#
# `in_out_ptr0` is therefore written twice: the element-wise sum from loop 1 is
# overwritten by the scaled result from loop 2, so only the latter remains once
# the kernel finishes. `inductor_meta['mutated_arg_names']` lists `in_out_ptr0`
# accordingly. Every pointer is declared '*bf16' in the `triton_meta` signature.
#
# NOTE(review): the graph fragment in the comment directly above shows a cast to
# bf16 (convert_element_type_813) before the final multiply, while the kernel
# keeps that intermediate in float32 (`tmp22 = tmp21.to(tl.float32)`).
# Presumably this is Inductor eliding the low-precision round trip; confirm if
# bit-exact parity with eager execution matters.
# NOTE(review): the pattern (sum of squares / N + eps -> rsqrt -> scale by a
# weight vector) looks like an RMSNorm over a residual sum; verify against the
# model source before relying on that reading.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_152 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_152', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_152', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 32768}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_152(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(r0_mask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp8, r0_mask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp14 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tl.full([1, 1], 2048.0, tl.float32)
        tmp17 = (tmp12 / tmp16)
        tmp18 = tl.full([1, 1], 1e-05, tl.float32)
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp24, r0_mask)
''', device_str='cuda')

def partition_0(args):
    arg5_1, arg4_1, arg6_1, arg7_1, arg3_1, arg1_1, arg0_1, arg2_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, 
arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1 = args
    args.clear()
    assert_size_stride(arg5_1, (1, 2048), (2048, 1))
    assert_size_stride(arg4_1, (2048, ), (1, ))
    assert_size_stride(arg6_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg7_1, (256, 2048), (2048, 1))
    assert_size_stride(arg3_1, (2, 28, 1, 2, 8192, 128), (58720256, 2097152, 2097152, 1048576, 128, 1))
    assert_size_stride(arg1_1, (1, ), (1, ))
    assert_size_stride(arg0_1, (32768, 128), (128, 1))
    assert_size_stride(arg2_1, (32768, 128), (128, 1))
    assert_size_stride(arg8_1, (256, 2048), (2048, 1))
    assert_size_stride(arg9_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg10_1, (2048, ), (1, ))
    assert_size_stride(arg11_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg12_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg13_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg14_1, (2048, ), (1, ))
    assert_size_stride(arg15_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg16_1, (256, 2048), (2048, 1))
    assert_size_stride(arg17_1, (256, 2048), (2048, 1))
    assert_size_stride(arg18_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg19_1, (2048, ), (1, ))
    assert_size_stride(arg20_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg21_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg22_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg23_1, (2048, ), (1, ))
    assert_size_stride(arg24_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg25_1, (256, 2048), (2048, 1))
    assert_size_stride(arg26_1, (256, 2048), (2048, 1))
    assert_size_stride(arg27_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg28_1, (2048, ), (1, ))
    assert_size_stride(arg29_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg30_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg31_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg32_1, (2048, ), (1, ))
    assert_size_stride(arg33_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg34_1, (256, 2048), (2048, 1))
    assert_size_stride(arg35_1, (256, 2048), (2048, 1))
    assert_size_stride(arg36_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg37_1, (2048, ), (1, ))
    assert_size_stride(arg38_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg39_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg40_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg41_1, (2048, ), (1, ))
    assert_size_stride(arg42_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg43_1, (256, 2048), (2048, 1))
    assert_size_stride(arg44_1, (256, 2048), (2048, 1))
    assert_size_stride(arg45_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg46_1, (2048, ), (1, ))
    assert_size_stride(arg47_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg48_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg49_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg50_1, (2048, ), (1, ))
    assert_size_stride(arg51_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg52_1, (256, 2048), (2048, 1))
    assert_size_stride(arg53_1, (256, 2048), (2048, 1))
    assert_size_stride(arg54_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg55_1, (2048, ), (1, ))
    assert_size_stride(arg56_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg57_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg58_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg59_1, (2048, ), (1, ))
    assert_size_stride(arg60_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg61_1, (256, 2048), (2048, 1))
    assert_size_stride(arg62_1, (256, 2048), (2048, 1))
    assert_size_stride(arg63_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg64_1, (2048, ), (1, ))
    assert_size_stride(arg65_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg66_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg67_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg68_1, (2048, ), (1, ))
    assert_size_stride(arg69_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg70_1, (256, 2048), (2048, 1))
    assert_size_stride(arg71_1, (256, 2048), (2048, 1))
    assert_size_stride(arg72_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg73_1, (2048, ), (1, ))
    assert_size_stride(arg74_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg75_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg76_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg77_1, (2048, ), (1, ))
    assert_size_stride(arg78_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg79_1, (256, 2048), (2048, 1))
    assert_size_stride(arg80_1, (256, 2048), (2048, 1))
    assert_size_stride(arg81_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg82_1, (2048, ), (1, ))
    assert_size_stride(arg83_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg84_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg85_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg86_1, (2048, ), (1, ))
    assert_size_stride(arg87_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg88_1, (256, 2048), (2048, 1))
    assert_size_stride(arg89_1, (256, 2048), (2048, 1))
    assert_size_stride(arg90_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg91_1, (2048, ), (1, ))
    assert_size_stride(arg92_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg93_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg94_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg95_1, (2048, ), (1, ))
    assert_size_stride(arg96_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg97_1, (256, 2048), (2048, 1))
    assert_size_stride(arg98_1, (256, 2048), (2048, 1))
    assert_size_stride(arg99_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg100_1, (2048, ), (1, ))
    assert_size_stride(arg101_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg102_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg103_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg104_1, (2048, ), (1, ))
    assert_size_stride(arg105_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg106_1, (256, 2048), (2048, 1))
    assert_size_stride(arg107_1, (256, 2048), (2048, 1))
    assert_size_stride(arg108_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg109_1, (2048, ), (1, ))
    assert_size_stride(arg110_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg111_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg112_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg113_1, (2048, ), (1, ))
    assert_size_stride(arg114_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg115_1, (256, 2048), (2048, 1))
    assert_size_stride(arg116_1, (256, 2048), (2048, 1))
    assert_size_stride(arg117_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg118_1, (2048, ), (1, ))
    assert_size_stride(arg119_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg120_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg121_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg122_1, (2048, ), (1, ))
    assert_size_stride(arg123_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg124_1, (256, 2048), (2048, 1))
    assert_size_stride(arg125_1, (256, 2048), (2048, 1))
    assert_size_stride(arg126_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg127_1, (2048, ), (1, ))
    assert_size_stride(arg128_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg129_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg130_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg131_1, (2048, ), (1, ))
    assert_size_stride(arg132_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg133_1, (256, 2048), (2048, 1))
    assert_size_stride(arg134_1, (256, 2048), (2048, 1))
    assert_size_stride(arg135_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg136_1, (2048, ), (1, ))
    assert_size_stride(arg137_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg138_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg139_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg140_1, (2048, ), (1, ))
    assert_size_stride(arg141_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg142_1, (256, 2048), (2048, 1))
    assert_size_stride(arg143_1, (256, 2048), (2048, 1))
    assert_size_stride(arg144_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg145_1, (2048, ), (1, ))
    assert_size_stride(arg146_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg147_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg148_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg149_1, (2048, ), (1, ))
    assert_size_stride(arg150_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg151_1, (256, 2048), (2048, 1))
    assert_size_stride(arg152_1, (256, 2048), (2048, 1))
    assert_size_stride(arg153_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg154_1, (2048, ), (1, ))
    assert_size_stride(arg155_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg156_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg157_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg158_1, (2048, ), (1, ))
    assert_size_stride(arg159_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg160_1, (256, 2048), (2048, 1))
    assert_size_stride(arg161_1, (256, 2048), (2048, 1))
    assert_size_stride(arg162_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg163_1, (2048, ), (1, ))
    assert_size_stride(arg164_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg165_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg166_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg167_1, (2048, ), (1, ))
    assert_size_stride(arg168_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg169_1, (256, 2048), (2048, 1))
    assert_size_stride(arg170_1, (256, 2048), (2048, 1))
    assert_size_stride(arg171_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg172_1, (2048, ), (1, ))
    assert_size_stride(arg173_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg174_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg175_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg176_1, (2048, ), (1, ))
    assert_size_stride(arg177_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg178_1, (256, 2048), (2048, 1))
    assert_size_stride(arg179_1, (256, 2048), (2048, 1))
    assert_size_stride(arg180_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg181_1, (2048, ), (1, ))
    assert_size_stride(arg182_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg183_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg184_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg185_1, (2048, ), (1, ))
    assert_size_stride(arg186_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg187_1, (256, 2048), (2048, 1))
    assert_size_stride(arg188_1, (256, 2048), (2048, 1))
    assert_size_stride(arg189_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg190_1, (2048, ), (1, ))
    assert_size_stride(arg191_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg192_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg193_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg194_1, (2048, ), (1, ))
    assert_size_stride(arg195_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg196_1, (256, 2048), (2048, 1))
    assert_size_stride(arg197_1, (256, 2048), (2048, 1))
    assert_size_stride(arg198_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg199_1, (2048, ), (1, ))
    assert_size_stride(arg200_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg201_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg202_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg203_1, (2048, ), (1, ))
    assert_size_stride(arg204_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg205_1, (256, 2048), (2048, 1))
    assert_size_stride(arg206_1, (256, 2048), (2048, 1))
    assert_size_stride(arg207_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg208_1, (2048, ), (1, ))
    assert_size_stride(arg209_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg210_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg211_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg212_1, (2048, ), (1, ))
    assert_size_stride(arg213_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg214_1, (256, 2048), (2048, 1))
    assert_size_stride(arg215_1, (256, 2048), (2048, 1))
    assert_size_stride(arg216_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg217_1, (2048, ), (1, ))
    assert_size_stride(arg218_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg219_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg220_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg221_1, (2048, ), (1, ))
    assert_size_stride(arg222_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg223_1, (256, 2048), (2048, 1))
    assert_size_stride(arg224_1, (256, 2048), (2048, 1))
    assert_size_stride(arg225_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg226_1, (2048, ), (1, ))
    assert_size_stride(arg227_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg228_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg229_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg230_1, (2048, ), (1, ))
    assert_size_stride(arg231_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg232_1, (256, 2048), (2048, 1))
    assert_size_stride(arg233_1, (256, 2048), (2048, 1))
    assert_size_stride(arg234_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg235_1, (2048, ), (1, ))
    assert_size_stride(arg236_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg237_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg238_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg239_1, (2048, ), (1, ))
    assert_size_stride(arg240_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg241_1, (256, 2048), (2048, 1))
    assert_size_stride(arg242_1, (256, 2048), (2048, 1))
    assert_size_stride(arg243_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg244_1, (2048, ), (1, ))
    assert_size_stride(arg245_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg246_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg247_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg248_1, (2048, ), (1, ))
    assert_size_stride(arg249_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg250_1, (256, 2048), (2048, 1))
    assert_size_stride(arg251_1, (256, 2048), (2048, 1))
    assert_size_stride(arg252_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg253_1, (2048, ), (1, ))
    assert_size_stride(arg254_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg255_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg256_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg257_1, (2048, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf1 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(arg5_1, arg4_1, buf1, 1, 2048, stream=stream0)
        del arg4_1
        buf2 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf1, reinterpret_tensor(arg6_1, (2048, 2048), (1, 2048), 0), out=buf2)
        del arg6_1
        buf3 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf1, reinterpret_tensor(arg7_1, (2048, 256), (1, 2048), 0), out=buf3)
        del arg7_1
        buf4 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [view_1, key_states_1, k, chunk_1, cos, sin, key_cache, mul_4, neg_1, cat_1, mul_5, k_embed, key_states_2, setitem], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.split, aten.index, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_1.run(arg3_1, buf4, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [view_1, key_states_1, k, chunk_1, cos, sin, key_cache, mul_4, neg_1, cat_1, mul_5, k_embed, key_states_2, setitem], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.split, aten.index, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf3, arg0_1, arg2_1, buf4, 256, stream=stream0)
        buf6 = buf3; del buf3  # reuse
        # Topologically Sorted Source Nodes: [value_states], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf1, reinterpret_tensor(arg8_1, (2048, 256), (1, 2048), 0), out=buf6)
        del arg8_1
        buf7 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_1, view_2, value_states_1], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_3.run(buf4, arg3_1, buf7, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_1, view_2, value_states_1], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf6, buf7, 256, stream=stream0)
        buf9 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_5.run(buf7, buf4, arg3_1, buf9, 117440512, stream=stream0)
        buf10 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view, query_states_1, q, chunk, cos, mul_2, neg, cat, sin, mul_3, q_embed, attn_output], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.split, aten.index, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf2, arg1_1, arg0_1, arg2_1, buf10, 2048, stream=stream0)
        buf11 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_7.run(buf9, buf11, 2048, 8192, stream=stream0)
        buf12 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view, query_states_1, q, chunk, cos, mul_2, neg, cat, sin, mul_3, q_embed, attn_output], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.split, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf10, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf11, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf12)
        buf16 = reinterpret_tensor(buf12, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf12  # reuse
        # Topologically Sorted Source Nodes: [attn_output, arange, attn_mask], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf16, arg1_1, 16, 8192, stream=stream0)
        buf17 = reinterpret_tensor(buf11, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf11  # reuse
        # Topologically Sorted Source Nodes: [setitem_1, attn_output], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_9.run(buf9, buf17, 16777216, stream=stream0)
        buf18 = reinterpret_tensor(buf10, (16, 1, 128), (128, 128, 1), 0); del buf10  # reuse
        # Topologically Sorted Source Nodes: [attn_output, arange, attn_mask, setitem_1], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf16, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf17, (16, 8192, 128), (1048576, 128, 1), 0), out=buf18)
        del buf16
        del buf17
        buf19 = reinterpret_tensor(buf2, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf2  # reuse
        # Topologically Sorted Source Nodes: [attn_output], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf18, buf19, 2048, stream=stream0)
        del buf18
        buf20 = buf1; del buf1  # reuse
        # Topologically Sorted Source Nodes: [attn_output, transpose_3, attn_output_2, attn_output_3], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf19, (1, 2048), (0, 1), 0), reinterpret_tensor(arg9_1, (2048, 2048), (1, 2048), 0), out=buf20)
        del arg9_1
        buf22 = reinterpret_tensor(buf19, (1, 2048), (2048, 1), 0); del buf19  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_1, to_6, pow_2, variance_1, add_4, rsqrt_1, mul_6, hidden_1, hidden_states_2], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(arg5_1, buf20, arg10_1, buf22, 1, 2048, stream=stream0)
        del arg10_1
        buf23 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_4], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf22, reinterpret_tensor(arg11_1, (2048, 6144), (1, 2048), 0), out=buf23)
        del arg11_1
        buf24 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_5], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf22, reinterpret_tensor(arg12_1, (2048, 6144), (1, 2048), 0), out=buf24)
        del arg12_1
        buf25 = buf23; del buf23  # reuse
        # Topologically Sorted Source Nodes: [silu, mul_8], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf25, buf24, 6144, stream=stream0)
        del buf24
        buf26 = buf22; del buf22  # reuse
        # Topologically Sorted Source Nodes: [silu, mul_8, hidden_states_3], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf25, reinterpret_tensor(arg13_1, (6144, 2048), (1, 6144), 0), out=buf26)
        del arg13_1
        del buf25
        buf28 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, to_8, pow_3, variance_2, add_6, rsqrt_2, mul_9, hidden_2, hidden_states_5], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(arg5_1, buf20, buf26, arg14_1, buf28, 1, 2048, stream=stream0)
        del arg14_1
        buf29 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_4], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf28, reinterpret_tensor(arg15_1, (2048, 2048), (1, 2048), 0), out=buf29)
        del arg15_1
        buf30 = buf6; del buf6  # reuse
        # Topologically Sorted Source Nodes: [key_states_3], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf28, reinterpret_tensor(arg16_1, (2048, 256), (1, 2048), 0), out=buf30)
        del arg16_1
        buf31 = buf7; del buf7  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_4, key_states_4, k_1, chunk_3, setitem_2, mul_13, neg_3, cat_3, mul_14, k_embed_1, key_states_5], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_14.run(buf9, buf31, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_4, key_states_4, k_1, chunk_3, setitem_2, mul_13, neg_3, cat_3, mul_14, k_embed_1, key_states_5], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf30, arg0_1, arg2_1, buf31, 256, stream=stream0)
        buf33 = buf30; del buf30  # reuse
        # Topologically Sorted Source Nodes: [value_states_2], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf28, reinterpret_tensor(arg17_1, (2048, 256), (1, 2048), 0), out=buf33)
        del arg17_1
        buf34 = buf4; del buf4  # reuse
        # Topologically Sorted Source Nodes: [setitem_3, view_5, value_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_15.run(buf31, buf9, buf34, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_3, view_5, value_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf33, buf34, 256, stream=stream0)
        del buf33
        buf36 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_16.run(buf34, buf31, buf9, buf36, 117440512, stream=stream0)
        del buf9
        buf37 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_3, query_states_5, q_1, chunk_2, mul_11, neg_2, cat_2, mul_12, q_embed_1, attn_output_4], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf29, arg1_1, arg0_1, arg2_1, buf37, 2048, stream=stream0)
        buf38 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_4], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_17.run(buf36, buf38, 2048, 8192, stream=stream0)
        buf39 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_3, query_states_5, q_1, chunk_2, mul_11, neg_2, cat_2, mul_12, q_embed_1, attn_output_4], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf37, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf38, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf39)
        buf43 = reinterpret_tensor(buf39, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf39  # reuse
        # Topologically Sorted Source Nodes: [attn_output_4, arange_1, attn_mask_1], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf43, arg1_1, 16, 8192, stream=stream0)
        buf44 = reinterpret_tensor(buf38, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf38  # reuse
        # Topologically Sorted Source Nodes: [setitem_3, attn_output_4], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_18.run(buf36, buf44, 16777216, stream=stream0)
        buf45 = reinterpret_tensor(buf37, (16, 1, 128), (128, 128, 1), 0); del buf37  # reuse
        # Topologically Sorted Source Nodes: [attn_output_4, arange_1, attn_mask_1, setitem_3], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf43, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf44, (16, 8192, 128), (1048576, 128, 1), 0), out=buf45)
        del buf43
        del buf44
        buf46 = reinterpret_tensor(buf29, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf29  # reuse
        # Topologically Sorted Source Nodes: [attn_output_4], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf45, buf46, 2048, stream=stream0)
        buf47 = buf28; del buf28  # reuse
        # Topologically Sorted Source Nodes: [attn_output_4, transpose_7, attn_output_6, attn_output_7], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf46, (1, 2048), (0, 1), 0), reinterpret_tensor(arg18_1, (2048, 2048), (1, 2048), 0), out=buf47)
        del arg18_1
        buf49 = reinterpret_tensor(buf46, (1, 2048), (2048, 1), 0); del buf46  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, hidden_states_6, to_14, pow_4, variance_3, add_10, rsqrt_3, mul_15, hidden_3, hidden_states_7], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(arg5_1, buf20, buf26, buf47, arg19_1, buf49, 1, 2048, stream=stream0)
        del arg19_1
        buf50 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_11], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf49, reinterpret_tensor(arg20_1, (2048, 6144), (1, 2048), 0), out=buf50)
        del arg20_1
        buf51 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_12], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf49, reinterpret_tensor(arg21_1, (2048, 6144), (1, 2048), 0), out=buf51)
        del arg21_1
        buf52 = buf50; del buf50  # reuse
        # Topologically Sorted Source Nodes: [silu_1, mul_17], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf52, buf51, 6144, stream=stream0)
        del buf51
        buf53 = buf49; del buf49  # reuse
        # Topologically Sorted Source Nodes: [silu_1, mul_17, hidden_states_8], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf52, reinterpret_tensor(arg22_1, (6144, 2048), (1, 6144), 0), out=buf53)
        del arg22_1
        del buf52
        buf54 = buf20; del buf20  # reuse
        buf56 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, hidden_states_6, hidden_states_9, to_16, pow_5, variance_4, add_12, rsqrt_4, mul_18, hidden_4, hidden_states_10], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_20.run(buf54, arg5_1, buf26, buf47, buf53, arg23_1, buf56, 1, 2048, stream=stream0)
        del arg23_1
        del arg5_1
        del buf26
        del buf47
        buf57 = buf53; del buf53  # reuse
        # Topologically Sorted Source Nodes: [query_states_8], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf56, reinterpret_tensor(arg24_1, (2048, 2048), (1, 2048), 0), out=buf57)
        del arg24_1
        buf58 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_6], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf56, reinterpret_tensor(arg25_1, (2048, 256), (1, 2048), 0), out=buf58)
        del arg25_1
        buf59 = buf34; del buf34  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_7, key_states_7, k_2, chunk_5, setitem_4, mul_22, neg_5, cat_5, mul_23, k_embed_2, key_states_8], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_21.run(buf36, buf59, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_7, key_states_7, k_2, chunk_5, setitem_4, mul_22, neg_5, cat_5, mul_23, k_embed_2, key_states_8], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf58, arg0_1, arg2_1, buf59, 256, stream=stream0)
        buf61 = buf58; del buf58  # reuse
        # Topologically Sorted Source Nodes: [value_states_4], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf56, reinterpret_tensor(arg26_1, (2048, 256), (1, 2048), 0), out=buf61)
        del arg26_1
        del buf56
        buf62 = buf31; del buf31  # reuse
        # Topologically Sorted Source Nodes: [setitem_5, view_8, value_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_22.run(buf59, buf36, buf62, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_5, view_8, value_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf61, buf62, 256, stream=stream0)
        del buf61
        buf64 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_23.run(buf62, buf59, buf36, buf64, 117440512, stream=stream0)
        del buf36
        buf65 = reinterpret_tensor(buf45, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf45  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_6, query_states_9, q_2, chunk_4, mul_20, neg_4, cat_4, mul_21, q_embed_2, attn_output_8], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf57, arg1_1, arg0_1, arg2_1, buf65, 2048, stream=stream0)
        buf66 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_8], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_24.run(buf64, buf66, 2048, 8192, stream=stream0)
        buf67 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_6, query_states_9, q_2, chunk_4, mul_20, neg_4, cat_4, mul_21, q_embed_2, attn_output_8], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf65, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf66, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf67)
        buf71 = reinterpret_tensor(buf67, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf67  # reuse
        # Topologically Sorted Source Nodes: [attn_output_8, arange_2, attn_mask_2], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf71, arg1_1, 16, 8192, stream=stream0)
        buf72 = reinterpret_tensor(buf66, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf66  # reuse
        # Topologically Sorted Source Nodes: [setitem_5, attn_output_8], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_25.run(buf64, buf72, 16777216, stream=stream0)
        buf73 = reinterpret_tensor(buf65, (16, 1, 128), (128, 128, 1), 0); del buf65  # reuse
        # Topologically Sorted Source Nodes: [attn_output_8, arange_2, attn_mask_2, setitem_5], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf71, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf72, (16, 8192, 128), (1048576, 128, 1), 0), out=buf73)
        del buf71
        del buf72
        buf74 = reinterpret_tensor(buf57, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf57  # reuse
        # Topologically Sorted Source Nodes: [attn_output_8], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf73, buf74, 2048, stream=stream0)
        del buf73
        buf75 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_8, transpose_11, attn_output_10, attn_output_11], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf74, (1, 2048), (0, 1), 0), reinterpret_tensor(arg27_1, (2048, 2048), (1, 2048), 0), out=buf75)
        del arg27_1
        buf77 = reinterpret_tensor(buf74, (1, 2048), (2048, 1), 0); del buf74  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_11, to_22, pow_6, variance_5, add_16, rsqrt_5, mul_24, hidden_5, hidden_states_12], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf54, buf75, arg28_1, buf77, 1, 2048, stream=stream0)
        del arg28_1
        buf78 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_18], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf77, reinterpret_tensor(arg29_1, (2048, 6144), (1, 2048), 0), out=buf78)
        del arg29_1
        buf79 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_19], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf77, reinterpret_tensor(arg30_1, (2048, 6144), (1, 2048), 0), out=buf79)
        del arg30_1
        buf80 = buf78; del buf78  # reuse
        # Topologically Sorted Source Nodes: [silu_2, mul_26], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf80, buf79, 6144, stream=stream0)
        del buf79
        buf81 = buf77; del buf77  # reuse
        # Topologically Sorted Source Nodes: [silu_2, mul_26, hidden_states_13], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf80, reinterpret_tensor(arg31_1, (6144, 2048), (1, 6144), 0), out=buf81)
        del arg31_1
        del buf80
        buf83 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_11, hidden_states_14, to_24, pow_7, variance_6, add_18, rsqrt_6, mul_27, hidden_6, hidden_states_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf54, buf75, buf81, arg32_1, buf83, 1, 2048, stream=stream0)
        del arg32_1
        buf84 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_12], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf83, reinterpret_tensor(arg33_1, (2048, 2048), (1, 2048), 0), out=buf84)
        del arg33_1
        buf85 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_9], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf83, reinterpret_tensor(arg34_1, (2048, 256), (1, 2048), 0), out=buf85)
        del arg34_1
        buf86 = buf62; del buf62  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_10, key_states_10, k_3, chunk_7, setitem_6, mul_31, neg_7, cat_7, mul_32, k_embed_3, key_states_11], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_26.run(buf64, buf86, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_10, key_states_10, k_3, chunk_7, setitem_6, mul_31, neg_7, cat_7, mul_32, k_embed_3, key_states_11], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf85, arg0_1, arg2_1, buf86, 256, stream=stream0)
        buf88 = buf85; del buf85  # reuse
        # Topologically Sorted Source Nodes: [value_states_6], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf83, reinterpret_tensor(arg35_1, (2048, 256), (1, 2048), 0), out=buf88)
        del arg35_1
        del buf83
        buf89 = buf59; del buf59  # reuse
        # Topologically Sorted Source Nodes: [setitem_7, view_11, value_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_27.run(buf86, buf64, buf89, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_7, view_11, value_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf88, buf89, 256, stream=stream0)
        del buf88
        buf91 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_28.run(buf89, buf86, buf64, buf91, 117440512, stream=stream0)
        del buf64
        buf92 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_9, query_states_13, q_3, chunk_6, mul_29, neg_6, cat_6, mul_30, q_embed_3, attn_output_12], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf84, arg1_1, arg0_1, arg2_1, buf92, 2048, stream=stream0)
        buf93 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_12], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_29.run(buf91, buf93, 2048, 8192, stream=stream0)
        buf94 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_9, query_states_13, q_3, chunk_6, mul_29, neg_6, cat_6, mul_30, q_embed_3, attn_output_12], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf92, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf93, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf94)
        buf98 = reinterpret_tensor(buf94, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf94  # reuse
        # Topologically Sorted Source Nodes: [attn_output_12, arange_3, attn_mask_3], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf98, arg1_1, 16, 8192, stream=stream0)
        buf99 = reinterpret_tensor(buf93, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf93  # reuse
        # Topologically Sorted Source Nodes: [setitem_7, attn_output_12], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_30.run(buf91, buf99, 16777216, stream=stream0)
        buf100 = reinterpret_tensor(buf92, (16, 1, 128), (128, 128, 1), 0); del buf92  # reuse
        # Topologically Sorted Source Nodes: [attn_output_12, arange_3, attn_mask_3, setitem_7], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf98, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf99, (16, 8192, 128), (1048576, 128, 1), 0), out=buf100)
        del buf98
        del buf99
        buf101 = reinterpret_tensor(buf84, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf84  # reuse
        # Topologically Sorted Source Nodes: [attn_output_12], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf100, buf101, 2048, stream=stream0)
        buf102 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_12, transpose_15, attn_output_14, attn_output_15], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf101, (1, 2048), (0, 1), 0), reinterpret_tensor(arg36_1, (2048, 2048), (1, 2048), 0), out=buf102)
        del arg36_1
        buf104 = reinterpret_tensor(buf101, (1, 2048), (2048, 1), 0); del buf101  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_11, hidden_states_14, hidden_states_16, to_30, pow_8, variance_7, add_22, rsqrt_7, mul_33, hidden_7, hidden_states_17], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf54, buf75, buf81, buf102, arg37_1, buf104, 1, 2048, stream=stream0)
        del arg37_1
        buf105 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_25], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf104, reinterpret_tensor(arg38_1, (2048, 6144), (1, 2048), 0), out=buf105)
        del arg38_1
        buf106 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_26], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf104, reinterpret_tensor(arg39_1, (2048, 6144), (1, 2048), 0), out=buf106)
        del arg39_1
        buf107 = buf105; del buf105  # reuse
        # Topologically Sorted Source Nodes: [silu_3, mul_35], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf107, buf106, 6144, stream=stream0)
        del buf106
        buf108 = buf104; del buf104  # reuse
        # Topologically Sorted Source Nodes: [silu_3, mul_35, hidden_states_18], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf107, reinterpret_tensor(arg40_1, (6144, 2048), (1, 6144), 0), out=buf108)
        del arg40_1
        del buf107
        buf109 = buf54; del buf54  # reuse
        buf111 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_11, hidden_states_14, hidden_states_16, hidden_states_19, to_32, pow_9, variance_8, add_24, rsqrt_8, mul_36, hidden_8, hidden_states_20], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf109, buf75, buf81, buf102, buf108, arg41_1, buf111, 1, 2048, stream=stream0)
        del arg41_1
        del buf102
        del buf108
        del buf75
        buf112 = buf81; del buf81  # reuse
        # Topologically Sorted Source Nodes: [query_states_16], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf111, reinterpret_tensor(arg42_1, (2048, 2048), (1, 2048), 0), out=buf112)
        del arg42_1
        buf113 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_12], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf111, reinterpret_tensor(arg43_1, (2048, 256), (1, 2048), 0), out=buf113)
        del arg43_1
        buf114 = buf89; del buf89  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_13, key_states_13, k_4, chunk_9, setitem_8, mul_40, neg_9, cat_9, mul_41, k_embed_4, key_states_14], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_32.run(buf91, buf114, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_13, key_states_13, k_4, chunk_9, setitem_8, mul_40, neg_9, cat_9, mul_41, k_embed_4, key_states_14], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf113, arg0_1, arg2_1, buf114, 256, stream=stream0)
        buf116 = buf113; del buf113  # reuse
        # Topologically Sorted Source Nodes: [value_states_8], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf111, reinterpret_tensor(arg44_1, (2048, 256), (1, 2048), 0), out=buf116)
        del arg44_1
        del buf111
        buf117 = buf86; del buf86  # reuse
        # Topologically Sorted Source Nodes: [setitem_9, view_14, value_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_33.run(buf114, buf91, buf117, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_9, view_14, value_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf116, buf117, 256, stream=stream0)
        del buf116
        buf119 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_34.run(buf117, buf114, buf91, buf119, 117440512, stream=stream0)
        del buf91
        buf120 = reinterpret_tensor(buf100, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf100  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_12, query_states_17, q_4, chunk_8, mul_38, neg_8, cat_8, mul_39, q_embed_4, attn_output_16], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf112, arg1_1, arg0_1, arg2_1, buf120, 2048, stream=stream0)
        buf121 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_16], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_35.run(buf119, buf121, 2048, 8192, stream=stream0)
        buf122 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_12, query_states_17, q_4, chunk_8, mul_38, neg_8, cat_8, mul_39, q_embed_4, attn_output_16], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf120, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf121, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf122)
        buf126 = reinterpret_tensor(buf122, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf122  # reuse
        # Topologically Sorted Source Nodes: [attn_output_16, arange_4, attn_mask_4], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf126, arg1_1, 16, 8192, stream=stream0)
        buf127 = reinterpret_tensor(buf121, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf121  # reuse
        # Topologically Sorted Source Nodes: [setitem_9, attn_output_16], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_36.run(buf119, buf127, 16777216, stream=stream0)
        buf128 = reinterpret_tensor(buf120, (16, 1, 128), (128, 128, 1), 0); del buf120  # reuse
        # Topologically Sorted Source Nodes: [attn_output_16, arange_4, attn_mask_4, setitem_9], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf126, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf127, (16, 8192, 128), (1048576, 128, 1), 0), out=buf128)
        del buf126
        del buf127
        buf129 = reinterpret_tensor(buf112, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf112  # reuse
        # Topologically Sorted Source Nodes: [attn_output_16], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf128, buf129, 2048, stream=stream0)
        del buf128
        buf130 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_16, transpose_19, attn_output_18, attn_output_19], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf129, (1, 2048), (0, 1), 0), reinterpret_tensor(arg45_1, (2048, 2048), (1, 2048), 0), out=buf130)
        del arg45_1
        buf132 = reinterpret_tensor(buf129, (1, 2048), (2048, 1), 0); del buf129  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_21, to_38, pow_10, variance_9, add_28, rsqrt_9, mul_42, hidden_9, hidden_states_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf109, buf130, arg46_1, buf132, 1, 2048, stream=stream0)
        del arg46_1
        buf133 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_32], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf132, reinterpret_tensor(arg47_1, (2048, 6144), (1, 2048), 0), out=buf133)
        del arg47_1
        buf134 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_33], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf132, reinterpret_tensor(arg48_1, (2048, 6144), (1, 2048), 0), out=buf134)
        del arg48_1
        buf135 = buf133; del buf133  # reuse
        # Topologically Sorted Source Nodes: [silu_4, mul_44], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf135, buf134, 6144, stream=stream0)
        del buf134
        buf136 = buf132; del buf132  # reuse
        # Topologically Sorted Source Nodes: [silu_4, mul_44, hidden_states_23], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf135, reinterpret_tensor(arg49_1, (6144, 2048), (1, 6144), 0), out=buf136)
        del arg49_1
        del buf135
        buf138 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_21, hidden_states_24, to_40, pow_11, variance_10, add_30, rsqrt_10, mul_45, hidden_10, hidden_states_25], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf109, buf130, buf136, arg50_1, buf138, 1, 2048, stream=stream0)
        del arg50_1
        buf139 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_20], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf138, reinterpret_tensor(arg51_1, (2048, 2048), (1, 2048), 0), out=buf139)
        del arg51_1
        buf140 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_15], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf138, reinterpret_tensor(arg52_1, (2048, 256), (1, 2048), 0), out=buf140)
        del arg52_1
        buf141 = buf117; del buf117  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_16, key_states_16, k_5, chunk_11, setitem_10, mul_49, neg_11, cat_11, mul_50, k_embed_5, key_states_17], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_37.run(buf119, buf141, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_16, key_states_16, k_5, chunk_11, setitem_10, mul_49, neg_11, cat_11, mul_50, k_embed_5, key_states_17], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf140, arg0_1, arg2_1, buf141, 256, stream=stream0)
        buf143 = buf140; del buf140  # reuse
        # Topologically Sorted Source Nodes: [value_states_10], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf138, reinterpret_tensor(arg53_1, (2048, 256), (1, 2048), 0), out=buf143)
        del arg53_1
        del buf138
        buf144 = buf114; del buf114  # reuse
        # Topologically Sorted Source Nodes: [setitem_11, view_17, value_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_38.run(buf141, buf119, buf144, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_11, view_17, value_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf143, buf144, 256, stream=stream0)
        del buf143
        buf146 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_39.run(buf144, buf141, buf119, buf146, 117440512, stream=stream0)
        del buf119
        buf147 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_15, query_states_21, q_5, chunk_10, mul_47, neg_10, cat_10, mul_48, q_embed_5, attn_output_20], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf139, arg1_1, arg0_1, arg2_1, buf147, 2048, stream=stream0)
        buf148 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_20], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_40.run(buf146, buf148, 2048, 8192, stream=stream0)
        buf149 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_15, query_states_21, q_5, chunk_10, mul_47, neg_10, cat_10, mul_48, q_embed_5, attn_output_20], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf147, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf148, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf149)
        buf153 = reinterpret_tensor(buf149, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf149  # reuse
        # Topologically Sorted Source Nodes: [attn_output_20, arange_5, attn_mask_5], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf153, arg1_1, 16, 8192, stream=stream0)
        buf154 = reinterpret_tensor(buf148, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf148  # reuse
        # Topologically Sorted Source Nodes: [setitem_11, attn_output_20], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_41.run(buf146, buf154, 16777216, stream=stream0)
        buf155 = reinterpret_tensor(buf147, (16, 1, 128), (128, 128, 1), 0); del buf147  # reuse
        # Topologically Sorted Source Nodes: [attn_output_20, arange_5, attn_mask_5, setitem_11], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf153, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf154, (16, 8192, 128), (1048576, 128, 1), 0), out=buf155)
        del buf153
        del buf154
        buf156 = reinterpret_tensor(buf139, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf139  # reuse
        # Topologically Sorted Source Nodes: [attn_output_20], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf155, buf156, 2048, stream=stream0)
        buf157 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_20, transpose_23, attn_output_22, attn_output_23], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf156, (1, 2048), (0, 1), 0), reinterpret_tensor(arg54_1, (2048, 2048), (1, 2048), 0), out=buf157)
        del arg54_1
        buf159 = reinterpret_tensor(buf156, (1, 2048), (2048, 1), 0); del buf156  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_21, hidden_states_24, hidden_states_26, to_46, pow_12, variance_11, add_34, rsqrt_11, mul_51, hidden_11, hidden_states_27], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf109, buf130, buf136, buf157, arg55_1, buf159, 1, 2048, stream=stream0)
        del arg55_1
        buf160 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_39], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf159, reinterpret_tensor(arg56_1, (2048, 6144), (1, 2048), 0), out=buf160)
        del arg56_1
        buf161 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_40], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf159, reinterpret_tensor(arg57_1, (2048, 6144), (1, 2048), 0), out=buf161)
        del arg57_1
        buf162 = buf160; del buf160  # reuse
        # Topologically Sorted Source Nodes: [silu_5, mul_53], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf162, buf161, 6144, stream=stream0)
        del buf161
        buf163 = buf159; del buf159  # reuse
        # Topologically Sorted Source Nodes: [silu_5, mul_53, hidden_states_28], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf162, reinterpret_tensor(arg58_1, (6144, 2048), (1, 6144), 0), out=buf163)
        del arg58_1
        del buf162
        buf164 = buf109; del buf109  # reuse
        buf166 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_21, hidden_states_24, hidden_states_26, hidden_states_29, to_48, pow_13, variance_12, add_36, rsqrt_12, mul_54, hidden_12, hidden_states_30], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf164, buf130, buf136, buf157, buf163, arg59_1, buf166, 1, 2048, stream=stream0)
        del arg59_1
        del buf130
        del buf136
        del buf157
        buf167 = buf163; del buf163  # reuse
        # Topologically Sorted Source Nodes: [query_states_24], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf166, reinterpret_tensor(arg60_1, (2048, 2048), (1, 2048), 0), out=buf167)
        del arg60_1
        buf168 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_18], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf166, reinterpret_tensor(arg61_1, (2048, 256), (1, 2048), 0), out=buf168)
        del arg61_1
        buf169 = buf144; del buf144  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_19, key_states_19, k_6, chunk_13, setitem_12, mul_58, neg_13, cat_13, mul_59, k_embed_6, key_states_20], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_42.run(buf146, buf169, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_19, key_states_19, k_6, chunk_13, setitem_12, mul_58, neg_13, cat_13, mul_59, k_embed_6, key_states_20], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf168, arg0_1, arg2_1, buf169, 256, stream=stream0)
        buf171 = buf168; del buf168  # reuse
        # Topologically Sorted Source Nodes: [value_states_12], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf166, reinterpret_tensor(arg62_1, (2048, 256), (1, 2048), 0), out=buf171)
        del arg62_1
        del buf166
        buf172 = buf141; del buf141  # reuse
        # Topologically Sorted Source Nodes: [setitem_13, view_20, value_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_43.run(buf169, buf146, buf172, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_13, view_20, value_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf171, buf172, 256, stream=stream0)
        del buf171
        buf174 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_44.run(buf172, buf169, buf146, buf174, 117440512, stream=stream0)
        del buf146
        buf175 = reinterpret_tensor(buf155, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf155  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_18, query_states_25, q_6, chunk_12, mul_56, neg_12, cat_12, mul_57, q_embed_6, attn_output_24], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf167, arg1_1, arg0_1, arg2_1, buf175, 2048, stream=stream0)
        buf176 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_24], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_45.run(buf174, buf176, 2048, 8192, stream=stream0)
        buf177 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_18, query_states_25, q_6, chunk_12, mul_56, neg_12, cat_12, mul_57, q_embed_6, attn_output_24], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf175, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf176, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf177)
        buf181 = reinterpret_tensor(buf177, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf177  # reuse
        # Topologically Sorted Source Nodes: [attn_output_24, arange_6, attn_mask_6], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf181, arg1_1, 16, 8192, stream=stream0)
        buf182 = reinterpret_tensor(buf176, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf176  # reuse
        # Topologically Sorted Source Nodes: [setitem_13, attn_output_24], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_46.run(buf174, buf182, 16777216, stream=stream0)
        buf183 = reinterpret_tensor(buf175, (16, 1, 128), (128, 128, 1), 0); del buf175  # reuse
        # Topologically Sorted Source Nodes: [attn_output_24, arange_6, attn_mask_6, setitem_13], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf181, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf182, (16, 8192, 128), (1048576, 128, 1), 0), out=buf183)
        del buf181
        del buf182
        buf184 = reinterpret_tensor(buf167, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf167  # reuse
        # Topologically Sorted Source Nodes: [attn_output_24], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf183, buf184, 2048, stream=stream0)
        del buf183
        buf185 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_24, transpose_27, attn_output_26, attn_output_27], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf184, (1, 2048), (0, 1), 0), reinterpret_tensor(arg63_1, (2048, 2048), (1, 2048), 0), out=buf185)
        del arg63_1
        buf187 = reinterpret_tensor(buf184, (1, 2048), (2048, 1), 0); del buf184  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_31, to_54, pow_14, variance_13, add_40, rsqrt_13, mul_60, hidden_13, hidden_states_32], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf164, buf185, arg64_1, buf187, 1, 2048, stream=stream0)
        del arg64_1
        buf188 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_46], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf187, reinterpret_tensor(arg65_1, (2048, 6144), (1, 2048), 0), out=buf188)
        del arg65_1
        buf189 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_47], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf187, reinterpret_tensor(arg66_1, (2048, 6144), (1, 2048), 0), out=buf189)
        del arg66_1
        buf190 = buf188; del buf188  # reuse
        # Topologically Sorted Source Nodes: [silu_6, mul_62], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf190, buf189, 6144, stream=stream0)
        del buf189
        buf191 = buf187; del buf187  # reuse
        # Topologically Sorted Source Nodes: [silu_6, mul_62, hidden_states_33], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf190, reinterpret_tensor(arg67_1, (6144, 2048), (1, 6144), 0), out=buf191)
        del arg67_1
        del buf190
        buf193 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_31, hidden_states_34, to_56, pow_15, variance_14, add_42, rsqrt_14, mul_63, hidden_14, hidden_states_35], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf164, buf185, buf191, arg68_1, buf193, 1, 2048, stream=stream0)
        del arg68_1
        buf194 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_28], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf193, reinterpret_tensor(arg69_1, (2048, 2048), (1, 2048), 0), out=buf194)
        del arg69_1
        buf195 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_21], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf193, reinterpret_tensor(arg70_1, (2048, 256), (1, 2048), 0), out=buf195)
        del arg70_1
        buf196 = buf172; del buf172  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_22, key_states_22, k_7, chunk_15, setitem_14, mul_67, neg_15, cat_15, mul_68, k_embed_7, key_states_23], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_47.run(buf174, buf196, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_22, key_states_22, k_7, chunk_15, setitem_14, mul_67, neg_15, cat_15, mul_68, k_embed_7, key_states_23], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf195, arg0_1, arg2_1, buf196, 256, stream=stream0)
        buf198 = buf195; del buf195  # reuse
        # Topologically Sorted Source Nodes: [value_states_14], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf193, reinterpret_tensor(arg71_1, (2048, 256), (1, 2048), 0), out=buf198)
        del arg71_1
        del buf193
        buf199 = buf169; del buf169  # reuse
        # Topologically Sorted Source Nodes: [setitem_15, view_23, value_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_48.run(buf196, buf174, buf199, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_15, view_23, value_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf198, buf199, 256, stream=stream0)
        del buf198
        buf201 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_49.run(buf199, buf196, buf174, buf201, 117440512, stream=stream0)
        del buf174
        buf202 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_21, query_states_29, q_7, chunk_14, mul_65, neg_14, cat_14, mul_66, q_embed_7, attn_output_28], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf194, arg1_1, arg0_1, arg2_1, buf202, 2048, stream=stream0)
        buf203 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_28], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_50.run(buf201, buf203, 2048, 8192, stream=stream0)
        buf204 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_21, query_states_29, q_7, chunk_14, mul_65, neg_14, cat_14, mul_66, q_embed_7, attn_output_28], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf202, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf203, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf204)
        buf208 = reinterpret_tensor(buf204, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf204  # reuse
        # Topologically Sorted Source Nodes: [attn_output_28, arange_7, attn_mask_7], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf208, arg1_1, 16, 8192, stream=stream0)
        buf209 = reinterpret_tensor(buf203, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf203  # reuse
        # Topologically Sorted Source Nodes: [setitem_15, attn_output_28], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_51.run(buf201, buf209, 16777216, stream=stream0)
        buf210 = reinterpret_tensor(buf202, (16, 1, 128), (128, 128, 1), 0); del buf202  # reuse
        # Topologically Sorted Source Nodes: [attn_output_28, arange_7, attn_mask_7, setitem_15], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf208, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf209, (16, 8192, 128), (1048576, 128, 1), 0), out=buf210)
        del buf208
        del buf209
        buf211 = reinterpret_tensor(buf194, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf194  # reuse
        # Topologically Sorted Source Nodes: [attn_output_28], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf210, buf211, 2048, stream=stream0)
        buf212 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_28, transpose_31, attn_output_30, attn_output_31], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf211, (1, 2048), (0, 1), 0), reinterpret_tensor(arg72_1, (2048, 2048), (1, 2048), 0), out=buf212)
        del arg72_1
        buf214 = reinterpret_tensor(buf211, (1, 2048), (2048, 1), 0); del buf211  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_31, hidden_states_34, hidden_states_36, to_62, pow_16, variance_15, add_46, rsqrt_15, mul_69, hidden_15, hidden_states_37], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf164, buf185, buf191, buf212, arg73_1, buf214, 1, 2048, stream=stream0)
        del arg73_1
        buf215 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_53], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf214, reinterpret_tensor(arg74_1, (2048, 6144), (1, 2048), 0), out=buf215)
        del arg74_1
        buf216 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_54], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf214, reinterpret_tensor(arg75_1, (2048, 6144), (1, 2048), 0), out=buf216)
        del arg75_1
        buf217 = buf215; del buf215  # reuse
        # Topologically Sorted Source Nodes: [silu_7, mul_71], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf217, buf216, 6144, stream=stream0)
        del buf216
        buf218 = buf214; del buf214  # reuse
        # Topologically Sorted Source Nodes: [silu_7, mul_71, hidden_states_38], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf217, reinterpret_tensor(arg76_1, (6144, 2048), (1, 6144), 0), out=buf218)
        del arg76_1
        del buf217
        buf219 = buf164; del buf164  # reuse
        buf221 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_31, hidden_states_34, hidden_states_36, hidden_states_39, to_64, pow_17, variance_16, add_48, rsqrt_16, mul_72, hidden_16, hidden_states_40], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf219, buf185, buf191, buf212, buf218, arg77_1, buf221, 1, 2048, stream=stream0)
        del arg77_1
        del buf185
        del buf191
        del buf212
        buf222 = buf218; del buf218  # reuse
        # Topologically Sorted Source Nodes: [query_states_32], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf221, reinterpret_tensor(arg78_1, (2048, 2048), (1, 2048), 0), out=buf222)
        del arg78_1
        buf223 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_24], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf221, reinterpret_tensor(arg79_1, (2048, 256), (1, 2048), 0), out=buf223)
        del arg79_1
        buf224 = buf199; del buf199  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_25, key_states_25, k_8, chunk_17, setitem_16, mul_76, neg_17, cat_17, mul_77, k_embed_8, key_states_26], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_52.run(buf201, buf224, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_25, key_states_25, k_8, chunk_17, setitem_16, mul_76, neg_17, cat_17, mul_77, k_embed_8, key_states_26], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf223, arg0_1, arg2_1, buf224, 256, stream=stream0)
        buf226 = buf223; del buf223  # reuse
        # Topologically Sorted Source Nodes: [value_states_16], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf221, reinterpret_tensor(arg80_1, (2048, 256), (1, 2048), 0), out=buf226)
        del arg80_1
        del buf221
        buf227 = buf196; del buf196  # reuse
        # Topologically Sorted Source Nodes: [setitem_17, view_26, value_states_17], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_53.run(buf224, buf201, buf227, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_17, view_26, value_states_17], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf226, buf227, 256, stream=stream0)
        del buf226
        buf229 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_54.run(buf227, buf224, buf201, buf229, 117440512, stream=stream0)
        del buf201
        buf230 = reinterpret_tensor(buf210, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf210  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_24, query_states_33, q_8, chunk_16, mul_74, neg_16, cat_16, mul_75, q_embed_8, attn_output_32], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf222, arg1_1, arg0_1, arg2_1, buf230, 2048, stream=stream0)
        buf231 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_32], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_55.run(buf229, buf231, 2048, 8192, stream=stream0)
        buf232 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_24, query_states_33, q_8, chunk_16, mul_74, neg_16, cat_16, mul_75, q_embed_8, attn_output_32], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf230, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf231, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf232)
        buf236 = reinterpret_tensor(buf232, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf232  # reuse
        # Topologically Sorted Source Nodes: [attn_output_32, arange_8, attn_mask_8], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf236, arg1_1, 16, 8192, stream=stream0)
        buf237 = reinterpret_tensor(buf231, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf231  # reuse
        # Topologically Sorted Source Nodes: [setitem_17, attn_output_32], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_56.run(buf229, buf237, 16777216, stream=stream0)
        buf238 = reinterpret_tensor(buf230, (16, 1, 128), (128, 128, 1), 0); del buf230  # reuse
        # Topologically Sorted Source Nodes: [attn_output_32, arange_8, attn_mask_8, setitem_17], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf236, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf237, (16, 8192, 128), (1048576, 128, 1), 0), out=buf238)
        del buf236
        del buf237
        buf239 = reinterpret_tensor(buf222, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf222  # reuse
        # Topologically Sorted Source Nodes: [attn_output_32], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf238, buf239, 2048, stream=stream0)
        del buf238
        buf240 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_32, transpose_35, attn_output_34, attn_output_35], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf239, (1, 2048), (0, 1), 0), reinterpret_tensor(arg81_1, (2048, 2048), (1, 2048), 0), out=buf240)
        del arg81_1
        buf242 = reinterpret_tensor(buf239, (1, 2048), (2048, 1), 0); del buf239  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_41, to_70, pow_18, variance_17, add_52, rsqrt_17, mul_78, hidden_17, hidden_states_42], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf219, buf240, arg82_1, buf242, 1, 2048, stream=stream0)
        del arg82_1
        buf243 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_60], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf242, reinterpret_tensor(arg83_1, (2048, 6144), (1, 2048), 0), out=buf243)
        del arg83_1
        buf244 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_61], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf242, reinterpret_tensor(arg84_1, (2048, 6144), (1, 2048), 0), out=buf244)
        del arg84_1
        buf245 = buf243; del buf243  # reuse
        # Topologically Sorted Source Nodes: [silu_8, mul_80], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf245, buf244, 6144, stream=stream0)
        del buf244
        buf246 = buf242; del buf242  # reuse
        # Topologically Sorted Source Nodes: [silu_8, mul_80, hidden_states_43], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf245, reinterpret_tensor(arg85_1, (6144, 2048), (1, 6144), 0), out=buf246)
        del arg85_1
        del buf245
        buf248 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_41, hidden_states_44, to_72, pow_19, variance_18, add_54, rsqrt_18, mul_81, hidden_18, hidden_states_45], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf219, buf240, buf246, arg86_1, buf248, 1, 2048, stream=stream0)
        del arg86_1
        buf249 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_36], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf248, reinterpret_tensor(arg87_1, (2048, 2048), (1, 2048), 0), out=buf249)
        del arg87_1
        buf250 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_27], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf248, reinterpret_tensor(arg88_1, (2048, 256), (1, 2048), 0), out=buf250)
        del arg88_1
        buf251 = buf227; del buf227  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_28, key_states_28, k_9, chunk_19, setitem_18, mul_85, neg_19, cat_19, mul_86, k_embed_9, key_states_29], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_57.run(buf229, buf251, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_28, key_states_28, k_9, chunk_19, setitem_18, mul_85, neg_19, cat_19, mul_86, k_embed_9, key_states_29], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf250, arg0_1, arg2_1, buf251, 256, stream=stream0)
        buf253 = buf250; del buf250  # reuse
        # Topologically Sorted Source Nodes: [value_states_18], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf248, reinterpret_tensor(arg89_1, (2048, 256), (1, 2048), 0), out=buf253)
        del arg89_1
        del buf248
        buf254 = buf224; del buf224  # reuse
        # Topologically Sorted Source Nodes: [setitem_19, view_29, value_states_19], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_58.run(buf251, buf229, buf254, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_19, view_29, value_states_19], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf253, buf254, 256, stream=stream0)
        del buf253
        buf256 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_59.run(buf254, buf251, buf229, buf256, 117440512, stream=stream0)
        del buf229
        buf257 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_27, query_states_37, q_9, chunk_18, mul_83, neg_18, cat_18, mul_84, q_embed_9, attn_output_36], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf249, arg1_1, arg0_1, arg2_1, buf257, 2048, stream=stream0)
        buf258 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_36], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_60.run(buf256, buf258, 2048, 8192, stream=stream0)
        buf259 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_27, query_states_37, q_9, chunk_18, mul_83, neg_18, cat_18, mul_84, q_embed_9, attn_output_36], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf257, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf258, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf259)
        buf263 = reinterpret_tensor(buf259, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf259  # reuse
        # Topologically Sorted Source Nodes: [attn_output_36, arange_9, attn_mask_9], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf263, arg1_1, 16, 8192, stream=stream0)
        buf264 = reinterpret_tensor(buf258, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf258  # reuse
        # Topologically Sorted Source Nodes: [setitem_19, attn_output_36], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_61.run(buf256, buf264, 16777216, stream=stream0)
        buf265 = reinterpret_tensor(buf257, (16, 1, 128), (128, 128, 1), 0); del buf257  # reuse
        # Topologically Sorted Source Nodes: [attn_output_36, arange_9, attn_mask_9, setitem_19], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf263, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf264, (16, 8192, 128), (1048576, 128, 1), 0), out=buf265)
        del buf263
        del buf264
        buf266 = reinterpret_tensor(buf249, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf249  # reuse
        # Topologically Sorted Source Nodes: [attn_output_36], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf265, buf266, 2048, stream=stream0)
        buf267 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_36, transpose_39, attn_output_38, attn_output_39], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf266, (1, 2048), (0, 1), 0), reinterpret_tensor(arg90_1, (2048, 2048), (1, 2048), 0), out=buf267)
        del arg90_1
        buf269 = reinterpret_tensor(buf266, (1, 2048), (2048, 1), 0); del buf266  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_41, hidden_states_44, hidden_states_46, to_78, pow_20, variance_19, add_58, rsqrt_19, mul_87, hidden_19, hidden_states_47], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf219, buf240, buf246, buf267, arg91_1, buf269, 1, 2048, stream=stream0)
        del arg91_1
        buf270 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_67], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf269, reinterpret_tensor(arg92_1, (2048, 6144), (1, 2048), 0), out=buf270)
        del arg92_1
        buf271 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_68], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf269, reinterpret_tensor(arg93_1, (2048, 6144), (1, 2048), 0), out=buf271)
        del arg93_1
        buf272 = buf270; del buf270  # reuse
        # Topologically Sorted Source Nodes: [silu_9, mul_89], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf272, buf271, 6144, stream=stream0)
        del buf271
        buf273 = buf269; del buf269  # reuse
        # Topologically Sorted Source Nodes: [silu_9, mul_89, hidden_states_48], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf272, reinterpret_tensor(arg94_1, (6144, 2048), (1, 6144), 0), out=buf273)
        del arg94_1
        del buf272
        buf274 = buf219; del buf219  # reuse
        buf276 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_41, hidden_states_44, hidden_states_46, hidden_states_49, to_80, pow_21, variance_20, add_60, rsqrt_20, mul_90, hidden_20, hidden_states_50], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf274, buf240, buf246, buf267, buf273, arg95_1, buf276, 1, 2048, stream=stream0)
        del arg95_1
        del buf240
        del buf246
        del buf267
        buf277 = buf273; del buf273  # reuse
        # Topologically Sorted Source Nodes: [query_states_40], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf276, reinterpret_tensor(arg96_1, (2048, 2048), (1, 2048), 0), out=buf277)
        del arg96_1
        buf278 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_30], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf276, reinterpret_tensor(arg97_1, (2048, 256), (1, 2048), 0), out=buf278)
        del arg97_1
        buf279 = buf254; del buf254  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_31, key_states_31, k_10, chunk_21, setitem_20, mul_94, neg_21, cat_21, mul_95, k_embed_10, key_states_32], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_62.run(buf256, buf279, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_31, key_states_31, k_10, chunk_21, setitem_20, mul_94, neg_21, cat_21, mul_95, k_embed_10, key_states_32], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf278, arg0_1, arg2_1, buf279, 256, stream=stream0)
        buf281 = buf278; del buf278  # reuse
        # Topologically Sorted Source Nodes: [value_states_20], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf276, reinterpret_tensor(arg98_1, (2048, 256), (1, 2048), 0), out=buf281)
        del arg98_1
        del buf276
        buf282 = buf251; del buf251  # reuse
        # Topologically Sorted Source Nodes: [setitem_21, view_32, value_states_21], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_63.run(buf279, buf256, buf282, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_21, view_32, value_states_21], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf281, buf282, 256, stream=stream0)
        del buf281
        buf284 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_64.run(buf282, buf279, buf256, buf284, 117440512, stream=stream0)
        del buf256
        buf285 = reinterpret_tensor(buf265, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf265  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_30, query_states_41, q_10, chunk_20, mul_92, neg_20, cat_20, mul_93, q_embed_10, attn_output_40], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf277, arg1_1, arg0_1, arg2_1, buf285, 2048, stream=stream0)
        buf286 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_40], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_65.run(buf284, buf286, 2048, 8192, stream=stream0)
        buf287 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_30, query_states_41, q_10, chunk_20, mul_92, neg_20, cat_20, mul_93, q_embed_10, attn_output_40], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf285, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf286, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf287)
        buf291 = reinterpret_tensor(buf287, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf287  # reuse
        # Topologically Sorted Source Nodes: [attn_output_40, arange_10, attn_mask_10], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf291, arg1_1, 16, 8192, stream=stream0)
        buf292 = reinterpret_tensor(buf286, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf286  # reuse
        # Topologically Sorted Source Nodes: [setitem_21, attn_output_40], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_66.run(buf284, buf292, 16777216, stream=stream0)
        buf293 = reinterpret_tensor(buf285, (16, 1, 128), (128, 128, 1), 0); del buf285  # reuse
        # Topologically Sorted Source Nodes: [attn_output_40, arange_10, attn_mask_10, setitem_21], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf291, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf292, (16, 8192, 128), (1048576, 128, 1), 0), out=buf293)
        del buf291
        del buf292
        buf294 = reinterpret_tensor(buf277, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf277  # reuse
        # Topologically Sorted Source Nodes: [attn_output_40], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf293, buf294, 2048, stream=stream0)
        del buf293
        buf295 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_40, transpose_43, attn_output_42, attn_output_43], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf294, (1, 2048), (0, 1), 0), reinterpret_tensor(arg99_1, (2048, 2048), (1, 2048), 0), out=buf295)
        del arg99_1
        buf297 = reinterpret_tensor(buf294, (1, 2048), (2048, 1), 0); del buf294  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_51, to_86, pow_22, variance_21, add_64, rsqrt_21, mul_96, hidden_21, hidden_states_52], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf274, buf295, arg100_1, buf297, 1, 2048, stream=stream0)
        del arg100_1
        buf298 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_74], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf297, reinterpret_tensor(arg101_1, (2048, 6144), (1, 2048), 0), out=buf298)
        del arg101_1
        buf299 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_75], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf297, reinterpret_tensor(arg102_1, (2048, 6144), (1, 2048), 0), out=buf299)
        del arg102_1
        buf300 = buf298; del buf298  # reuse
        # Topologically Sorted Source Nodes: [silu_10, mul_98], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf300, buf299, 6144, stream=stream0)
        del buf299
        buf301 = buf297; del buf297  # reuse
        # Topologically Sorted Source Nodes: [silu_10, mul_98, hidden_states_53], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf300, reinterpret_tensor(arg103_1, (6144, 2048), (1, 6144), 0), out=buf301)
        del arg103_1
        del buf300
        buf303 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_51, hidden_states_54, to_88, pow_23, variance_22, add_66, rsqrt_22, mul_99, hidden_22, hidden_states_55], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf274, buf295, buf301, arg104_1, buf303, 1, 2048, stream=stream0)
        del arg104_1
        buf304 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_44], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf303, reinterpret_tensor(arg105_1, (2048, 2048), (1, 2048), 0), out=buf304)
        del arg105_1
        buf305 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_33], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf303, reinterpret_tensor(arg106_1, (2048, 256), (1, 2048), 0), out=buf305)
        del arg106_1
        buf306 = buf282; del buf282  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_34, key_states_34, k_11, chunk_23, setitem_22, mul_103, neg_23, cat_23, mul_104, k_embed_11, key_states_35], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_67.run(buf284, buf306, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_34, key_states_34, k_11, chunk_23, setitem_22, mul_103, neg_23, cat_23, mul_104, k_embed_11, key_states_35], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf305, arg0_1, arg2_1, buf306, 256, stream=stream0)
        buf308 = buf305; del buf305  # reuse
        # Topologically Sorted Source Nodes: [value_states_22], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf303, reinterpret_tensor(arg107_1, (2048, 256), (1, 2048), 0), out=buf308)
        del arg107_1
        del buf303
        buf309 = buf279; del buf279  # reuse
        # Topologically Sorted Source Nodes: [setitem_23, view_35, value_states_23], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_68.run(buf306, buf284, buf309, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_23, view_35, value_states_23], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf308, buf309, 256, stream=stream0)
        del buf308
        buf311 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_69.run(buf309, buf306, buf284, buf311, 117440512, stream=stream0)
        del buf284
        buf312 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_33, query_states_45, q_11, chunk_22, mul_101, neg_22, cat_22, mul_102, q_embed_11, attn_output_44], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf304, arg1_1, arg0_1, arg2_1, buf312, 2048, stream=stream0)
        buf313 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_44], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_70.run(buf311, buf313, 2048, 8192, stream=stream0)
        buf314 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_33, query_states_45, q_11, chunk_22, mul_101, neg_22, cat_22, mul_102, q_embed_11, attn_output_44], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf312, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf313, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf314)
        buf318 = reinterpret_tensor(buf314, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf314  # reuse
        # Topologically Sorted Source Nodes: [attn_output_44, arange_11, attn_mask_11], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf318, arg1_1, 16, 8192, stream=stream0)
        buf319 = reinterpret_tensor(buf313, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf313  # reuse
        # Topologically Sorted Source Nodes: [setitem_23, attn_output_44], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_71.run(buf311, buf319, 16777216, stream=stream0)
        buf320 = reinterpret_tensor(buf312, (16, 1, 128), (128, 128, 1), 0); del buf312  # reuse
        # Topologically Sorted Source Nodes: [attn_output_44, arange_11, attn_mask_11, setitem_23], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf318, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf319, (16, 8192, 128), (1048576, 128, 1), 0), out=buf320)
        del buf318
        del buf319
        buf321 = reinterpret_tensor(buf304, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf304  # reuse
        # Topologically Sorted Source Nodes: [attn_output_44], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf320, buf321, 2048, stream=stream0)
        buf322 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_44, transpose_47, attn_output_46, attn_output_47], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf321, (1, 2048), (0, 1), 0), reinterpret_tensor(arg108_1, (2048, 2048), (1, 2048), 0), out=buf322)
        del arg108_1
        buf324 = reinterpret_tensor(buf321, (1, 2048), (2048, 1), 0); del buf321  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_51, hidden_states_54, hidden_states_56, to_94, pow_24, variance_23, add_70, rsqrt_23, mul_105, hidden_23, hidden_states_57], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf274, buf295, buf301, buf322, arg109_1, buf324, 1, 2048, stream=stream0)
        del arg109_1
        buf325 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_81], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf324, reinterpret_tensor(arg110_1, (2048, 6144), (1, 2048), 0), out=buf325)
        del arg110_1
        buf326 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_82], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf324, reinterpret_tensor(arg111_1, (2048, 6144), (1, 2048), 0), out=buf326)
        del arg111_1
        buf327 = buf325; del buf325  # reuse
        # Topologically Sorted Source Nodes: [silu_11, mul_107], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf327, buf326, 6144, stream=stream0)
        del buf326
        buf328 = buf324; del buf324  # reuse
        # Topologically Sorted Source Nodes: [silu_11, mul_107, hidden_states_58], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf327, reinterpret_tensor(arg112_1, (6144, 2048), (1, 6144), 0), out=buf328)
        del arg112_1
        del buf327
        buf329 = buf274; del buf274  # reuse
        buf331 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_51, hidden_states_54, hidden_states_56, hidden_states_59, to_96, pow_25, variance_24, add_72, rsqrt_24, mul_108, hidden_24, hidden_states_60], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf329, buf295, buf301, buf322, buf328, arg113_1, buf331, 1, 2048, stream=stream0)
        del arg113_1
        del buf295
        del buf301
        del buf322
        buf332 = buf328; del buf328  # reuse
        # Topologically Sorted Source Nodes: [query_states_48], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf331, reinterpret_tensor(arg114_1, (2048, 2048), (1, 2048), 0), out=buf332)
        del arg114_1
        buf333 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_36], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf331, reinterpret_tensor(arg115_1, (2048, 256), (1, 2048), 0), out=buf333)
        del arg115_1
        buf334 = buf309; del buf309  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_37, key_states_37, k_12, chunk_25, setitem_24, mul_112, neg_25, cat_25, mul_113, k_embed_12, key_states_38], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_72.run(buf311, buf334, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_37, key_states_37, k_12, chunk_25, setitem_24, mul_112, neg_25, cat_25, mul_113, k_embed_12, key_states_38], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf333, arg0_1, arg2_1, buf334, 256, stream=stream0)
        buf336 = buf333; del buf333  # reuse
        # Topologically Sorted Source Nodes: [value_states_24], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf331, reinterpret_tensor(arg116_1, (2048, 256), (1, 2048), 0), out=buf336)
        del arg116_1
        del buf331
        buf337 = buf306; del buf306  # reuse
        # Topologically Sorted Source Nodes: [setitem_25, view_38, value_states_25], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_73.run(buf334, buf311, buf337, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_25, view_38, value_states_25], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf336, buf337, 256, stream=stream0)
        del buf336
        buf339 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_74.run(buf337, buf334, buf311, buf339, 117440512, stream=stream0)
        del buf311
        buf340 = reinterpret_tensor(buf320, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf320  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_36, query_states_49, q_12, chunk_24, mul_110, neg_24, cat_24, mul_111, q_embed_12, attn_output_48], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf332, arg1_1, arg0_1, arg2_1, buf340, 2048, stream=stream0)
        buf341 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_48], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_75.run(buf339, buf341, 2048, 8192, stream=stream0)
        buf342 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_36, query_states_49, q_12, chunk_24, mul_110, neg_24, cat_24, mul_111, q_embed_12, attn_output_48], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf340, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf341, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf342)
        buf346 = reinterpret_tensor(buf342, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf342  # reuse
        # Topologically Sorted Source Nodes: [attn_output_48, arange_12, attn_mask_12], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf346, arg1_1, 16, 8192, stream=stream0)
        buf347 = reinterpret_tensor(buf341, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf341  # reuse
        # Topologically Sorted Source Nodes: [setitem_25, attn_output_48], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_76.run(buf339, buf347, 16777216, stream=stream0)
        buf348 = reinterpret_tensor(buf340, (16, 1, 128), (128, 128, 1), 0); del buf340  # reuse
        # Topologically Sorted Source Nodes: [attn_output_48, arange_12, attn_mask_12, setitem_25], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf346, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf347, (16, 8192, 128), (1048576, 128, 1), 0), out=buf348)
        del buf346
        del buf347
        buf349 = reinterpret_tensor(buf332, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf332  # reuse
        # Topologically Sorted Source Nodes: [attn_output_48], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf348, buf349, 2048, stream=stream0)
        del buf348
        buf350 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_48, transpose_51, attn_output_50, attn_output_51], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf349, (1, 2048), (0, 1), 0), reinterpret_tensor(arg117_1, (2048, 2048), (1, 2048), 0), out=buf350)
        del arg117_1
        buf352 = reinterpret_tensor(buf349, (1, 2048), (2048, 1), 0); del buf349  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_61, to_102, pow_26, variance_25, add_76, rsqrt_25, mul_114, hidden_25, hidden_states_62], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf329, buf350, arg118_1, buf352, 1, 2048, stream=stream0)
        del arg118_1
        buf353 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_88], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf352, reinterpret_tensor(arg119_1, (2048, 6144), (1, 2048), 0), out=buf353)
        del arg119_1
        buf354 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_89], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf352, reinterpret_tensor(arg120_1, (2048, 6144), (1, 2048), 0), out=buf354)
        del arg120_1
        buf355 = buf353; del buf353  # reuse
        # Topologically Sorted Source Nodes: [silu_12, mul_116], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf355, buf354, 6144, stream=stream0)
        del buf354
        buf356 = buf352; del buf352  # reuse
        # Topologically Sorted Source Nodes: [silu_12, mul_116, hidden_states_63], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf355, reinterpret_tensor(arg121_1, (6144, 2048), (1, 6144), 0), out=buf356)
        del arg121_1
        del buf355
        buf358 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_61, hidden_states_64, to_104, pow_27, variance_26, add_78, rsqrt_26, mul_117, hidden_26, hidden_states_65], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf329, buf350, buf356, arg122_1, buf358, 1, 2048, stream=stream0)
        del arg122_1
        buf359 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_52], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf358, reinterpret_tensor(arg123_1, (2048, 2048), (1, 2048), 0), out=buf359)
        del arg123_1
        buf360 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_39], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf358, reinterpret_tensor(arg124_1, (2048, 256), (1, 2048), 0), out=buf360)
        del arg124_1
        buf361 = buf337; del buf337  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_40, key_states_40, k_13, chunk_27, setitem_26, mul_121, neg_27, cat_27, mul_122, k_embed_13, key_states_41], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_77.run(buf339, buf361, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_40, key_states_40, k_13, chunk_27, setitem_26, mul_121, neg_27, cat_27, mul_122, k_embed_13, key_states_41], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf360, arg0_1, arg2_1, buf361, 256, stream=stream0)
        buf363 = buf360; del buf360  # reuse
        # Topologically Sorted Source Nodes: [value_states_26], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf358, reinterpret_tensor(arg125_1, (2048, 256), (1, 2048), 0), out=buf363)
        del arg125_1
        del buf358
        buf364 = buf334; del buf334  # reuse
        # Topologically Sorted Source Nodes: [setitem_27, view_41, value_states_27], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_78.run(buf361, buf339, buf364, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_27, view_41, value_states_27], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf363, buf364, 256, stream=stream0)
        del buf363
        buf366 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_79.run(buf364, buf361, buf339, buf366, 117440512, stream=stream0)
        del buf339
        buf367 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_39, query_states_53, q_13, chunk_26, mul_119, neg_26, cat_26, mul_120, q_embed_13, attn_output_52], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf359, arg1_1, arg0_1, arg2_1, buf367, 2048, stream=stream0)
        buf368 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_52], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_80.run(buf366, buf368, 2048, 8192, stream=stream0)
        buf369 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_39, query_states_53, q_13, chunk_26, mul_119, neg_26, cat_26, mul_120, q_embed_13, attn_output_52], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf367, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf368, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf369)
        buf373 = reinterpret_tensor(buf369, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf369  # reuse
        # Topologically Sorted Source Nodes: [attn_output_52, arange_13, attn_mask_13], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf373, arg1_1, 16, 8192, stream=stream0)
        buf374 = reinterpret_tensor(buf368, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf368  # reuse
        # Topologically Sorted Source Nodes: [setitem_27, attn_output_52], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_81.run(buf366, buf374, 16777216, stream=stream0)
        buf375 = reinterpret_tensor(buf367, (16, 1, 128), (128, 128, 1), 0); del buf367  # reuse
        # Topologically Sorted Source Nodes: [attn_output_52, arange_13, attn_mask_13, setitem_27], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf373, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf374, (16, 8192, 128), (1048576, 128, 1), 0), out=buf375)
        del buf373
        del buf374
        buf376 = reinterpret_tensor(buf359, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf359  # reuse
        # Topologically Sorted Source Nodes: [attn_output_52], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf375, buf376, 2048, stream=stream0)
        buf377 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_52, transpose_55, attn_output_54, attn_output_55], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf376, (1, 2048), (0, 1), 0), reinterpret_tensor(arg126_1, (2048, 2048), (1, 2048), 0), out=buf377)
        del arg126_1
        buf379 = reinterpret_tensor(buf376, (1, 2048), (2048, 1), 0); del buf376  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_61, hidden_states_64, hidden_states_66, to_110, pow_28, variance_27, add_82, rsqrt_27, mul_123, hidden_27, hidden_states_67], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf329, buf350, buf356, buf377, arg127_1, buf379, 1, 2048, stream=stream0)
        del arg127_1
        buf380 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_95], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf379, reinterpret_tensor(arg128_1, (2048, 6144), (1, 2048), 0), out=buf380)
        del arg128_1
        buf381 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_96], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf379, reinterpret_tensor(arg129_1, (2048, 6144), (1, 2048), 0), out=buf381)
        del arg129_1
        buf382 = buf380; del buf380  # reuse
        # Topologically Sorted Source Nodes: [silu_13, mul_125], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf382, buf381, 6144, stream=stream0)
        del buf381
        buf383 = buf379; del buf379  # reuse
        # Topologically Sorted Source Nodes: [silu_13, mul_125, hidden_states_68], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf382, reinterpret_tensor(arg130_1, (6144, 2048), (1, 6144), 0), out=buf383)
        del arg130_1
        del buf382
        buf384 = buf329; del buf329  # reuse
        buf386 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_61, hidden_states_64, hidden_states_66, hidden_states_69, to_112, pow_29, variance_28, add_84, rsqrt_28, mul_126, hidden_28, hidden_states_70], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf384, buf350, buf356, buf377, buf383, arg131_1, buf386, 1, 2048, stream=stream0)
        del arg131_1
        del buf350
        del buf356
        del buf377
        buf387 = buf383; del buf383  # reuse
        # Topologically Sorted Source Nodes: [query_states_56], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf386, reinterpret_tensor(arg132_1, (2048, 2048), (1, 2048), 0), out=buf387)
        del arg132_1
        buf388 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_42], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf386, reinterpret_tensor(arg133_1, (2048, 256), (1, 2048), 0), out=buf388)
        del arg133_1
        buf389 = buf364; del buf364  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_43, key_states_43, k_14, chunk_29, setitem_28, mul_130, neg_29, cat_29, mul_131, k_embed_14, key_states_44], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_82.run(buf366, buf389, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_43, key_states_43, k_14, chunk_29, setitem_28, mul_130, neg_29, cat_29, mul_131, k_embed_14, key_states_44], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf388, arg0_1, arg2_1, buf389, 256, stream=stream0)
        buf391 = buf388; del buf388  # reuse
        # Topologically Sorted Source Nodes: [value_states_28], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf386, reinterpret_tensor(arg134_1, (2048, 256), (1, 2048), 0), out=buf391)
        del arg134_1
        del buf386
        buf392 = buf361; del buf361  # reuse
        # Topologically Sorted Source Nodes: [setitem_29, view_44, value_states_29], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_83.run(buf389, buf366, buf392, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_29, view_44, value_states_29], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf391, buf392, 256, stream=stream0)
        del buf391
        buf394 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_84.run(buf392, buf389, buf366, buf394, 117440512, stream=stream0)
        del buf366
        buf395 = reinterpret_tensor(buf375, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf375  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_42, query_states_57, q_14, chunk_28, mul_128, neg_28, cat_28, mul_129, q_embed_14, attn_output_56], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf387, arg1_1, arg0_1, arg2_1, buf395, 2048, stream=stream0)
        buf396 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_56], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_85.run(buf394, buf396, 2048, 8192, stream=stream0)
        buf397 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_42, query_states_57, q_14, chunk_28, mul_128, neg_28, cat_28, mul_129, q_embed_14, attn_output_56], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf395, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf396, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf397)
        buf401 = reinterpret_tensor(buf397, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf397  # reuse
        # Topologically Sorted Source Nodes: [attn_output_56, arange_14, attn_mask_14], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf401, arg1_1, 16, 8192, stream=stream0)
        buf402 = reinterpret_tensor(buf396, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf396  # reuse
        # Topologically Sorted Source Nodes: [setitem_29, attn_output_56], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_86.run(buf394, buf402, 16777216, stream=stream0)
        buf403 = reinterpret_tensor(buf395, (16, 1, 128), (128, 128, 1), 0); del buf395  # reuse
        # Topologically Sorted Source Nodes: [attn_output_56, arange_14, attn_mask_14, setitem_29], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf401, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf402, (16, 8192, 128), (1048576, 128, 1), 0), out=buf403)
        del buf401
        del buf402
        buf404 = reinterpret_tensor(buf387, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf387  # reuse
        # Topologically Sorted Source Nodes: [attn_output_56], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf403, buf404, 2048, stream=stream0)
        del buf403
        buf405 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_56, transpose_59, attn_output_58, attn_output_59], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf404, (1, 2048), (0, 1), 0), reinterpret_tensor(arg135_1, (2048, 2048), (1, 2048), 0), out=buf405)
        del arg135_1
        buf407 = reinterpret_tensor(buf404, (1, 2048), (2048, 1), 0); del buf404  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_71, to_118, pow_30, variance_29, add_88, rsqrt_29, mul_132, hidden_29, hidden_states_72], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf384, buf405, arg136_1, buf407, 1, 2048, stream=stream0)
        del arg136_1
        buf408 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_102], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf407, reinterpret_tensor(arg137_1, (2048, 6144), (1, 2048), 0), out=buf408)
        del arg137_1
        buf409 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_103], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf407, reinterpret_tensor(arg138_1, (2048, 6144), (1, 2048), 0), out=buf409)
        del arg138_1
        buf410 = buf408; del buf408  # reuse
        # Topologically Sorted Source Nodes: [silu_14, mul_134], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf410, buf409, 6144, stream=stream0)
        del buf409
        buf411 = buf407; del buf407  # reuse
        # Topologically Sorted Source Nodes: [silu_14, mul_134, hidden_states_73], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf410, reinterpret_tensor(arg139_1, (6144, 2048), (1, 6144), 0), out=buf411)
        del arg139_1
        del buf410
        buf413 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_71, hidden_states_74, to_120, pow_31, variance_30, add_90, rsqrt_30, mul_135, hidden_30, hidden_states_75], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf384, buf405, buf411, arg140_1, buf413, 1, 2048, stream=stream0)
        del arg140_1
        buf414 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_60], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf413, reinterpret_tensor(arg141_1, (2048, 2048), (1, 2048), 0), out=buf414)
        del arg141_1
        buf415 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_45], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf413, reinterpret_tensor(arg142_1, (2048, 256), (1, 2048), 0), out=buf415)
        del arg142_1
        buf416 = buf392; del buf392  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_46, key_states_46, k_15, chunk_31, setitem_30, mul_139, neg_31, cat_31, mul_140, k_embed_15, key_states_47], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_87.run(buf394, buf416, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_46, key_states_46, k_15, chunk_31, setitem_30, mul_139, neg_31, cat_31, mul_140, k_embed_15, key_states_47], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf415, arg0_1, arg2_1, buf416, 256, stream=stream0)
        buf418 = buf415; del buf415  # reuse
        # Topologically Sorted Source Nodes: [value_states_30], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf413, reinterpret_tensor(arg143_1, (2048, 256), (1, 2048), 0), out=buf418)
        del arg143_1
        del buf413
        buf419 = buf389; del buf389  # reuse
        # Topologically Sorted Source Nodes: [setitem_31, view_47, value_states_31], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_88.run(buf416, buf394, buf419, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_31, view_47, value_states_31], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf418, buf419, 256, stream=stream0)
        del buf418
        buf421 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_89.run(buf419, buf416, buf394, buf421, 117440512, stream=stream0)
        del buf394
        buf422 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_45, query_states_61, q_15, chunk_30, mul_137, neg_30, cat_30, mul_138, q_embed_15, attn_output_60], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf414, arg1_1, arg0_1, arg2_1, buf422, 2048, stream=stream0)
        buf423 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_60], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_90.run(buf421, buf423, 2048, 8192, stream=stream0)
        buf424 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_45, query_states_61, q_15, chunk_30, mul_137, neg_30, cat_30, mul_138, q_embed_15, attn_output_60], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf422, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf423, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf424)
        buf428 = reinterpret_tensor(buf424, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf424  # reuse
        # Topologically Sorted Source Nodes: [attn_output_60, arange_15, attn_mask_15], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf428, arg1_1, 16, 8192, stream=stream0)
        buf429 = reinterpret_tensor(buf423, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf423  # reuse
        # Topologically Sorted Source Nodes: [setitem_31, attn_output_60], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_91.run(buf421, buf429, 16777216, stream=stream0)
        buf430 = reinterpret_tensor(buf422, (16, 1, 128), (128, 128, 1), 0); del buf422  # reuse
        # Topologically Sorted Source Nodes: [attn_output_60, arange_15, attn_mask_15, setitem_31], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf428, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf429, (16, 8192, 128), (1048576, 128, 1), 0), out=buf430)
        del buf428
        del buf429
        buf431 = reinterpret_tensor(buf414, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf414  # reuse
        # Topologically Sorted Source Nodes: [attn_output_60], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf430, buf431, 2048, stream=stream0)
        buf432 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_60, transpose_63, attn_output_62, attn_output_63], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf431, (1, 2048), (0, 1), 0), reinterpret_tensor(arg144_1, (2048, 2048), (1, 2048), 0), out=buf432)
        del arg144_1
        buf434 = reinterpret_tensor(buf431, (1, 2048), (2048, 1), 0); del buf431  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_71, hidden_states_74, hidden_states_76, to_126, pow_32, variance_31, add_94, rsqrt_31, mul_141, hidden_31, hidden_states_77], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf384, buf405, buf411, buf432, arg145_1, buf434, 1, 2048, stream=stream0)
        del arg145_1
        buf435 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_109], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf434, reinterpret_tensor(arg146_1, (2048, 6144), (1, 2048), 0), out=buf435)
        del arg146_1
        buf436 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_110], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf434, reinterpret_tensor(arg147_1, (2048, 6144), (1, 2048), 0), out=buf436)
        del arg147_1
        buf437 = buf435; del buf435  # reuse
        # Topologically Sorted Source Nodes: [silu_15, mul_143], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf437, buf436, 6144, stream=stream0)
        del buf436
        buf438 = buf434; del buf434  # reuse
        # Topologically Sorted Source Nodes: [silu_15, mul_143, hidden_states_78], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf437, reinterpret_tensor(arg148_1, (6144, 2048), (1, 6144), 0), out=buf438)
        del arg148_1
        del buf437
        buf439 = buf384; del buf384  # reuse
        buf441 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_71, hidden_states_74, hidden_states_76, hidden_states_79, to_128, pow_33, variance_32, add_96, rsqrt_32, mul_144, hidden_32, hidden_states_80], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf439, buf405, buf411, buf432, buf438, arg149_1, buf441, 1, 2048, stream=stream0)
        del arg149_1
        del buf405
        del buf411
        del buf432
        buf442 = buf438; del buf438  # reuse
        # Topologically Sorted Source Nodes: [query_states_64], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf441, reinterpret_tensor(arg150_1, (2048, 2048), (1, 2048), 0), out=buf442)
        del arg150_1
        buf443 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_48], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf441, reinterpret_tensor(arg151_1, (2048, 256), (1, 2048), 0), out=buf443)
        del arg151_1
        buf444 = buf419; del buf419  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_49, key_states_49, k_16, chunk_33, setitem_32, mul_148, neg_33, cat_33, mul_149, k_embed_16, key_states_50], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_92.run(buf421, buf444, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_49, key_states_49, k_16, chunk_33, setitem_32, mul_148, neg_33, cat_33, mul_149, k_embed_16, key_states_50], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf443, arg0_1, arg2_1, buf444, 256, stream=stream0)
        buf446 = buf443; del buf443  # reuse
        # Topologically Sorted Source Nodes: [value_states_32], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf441, reinterpret_tensor(arg152_1, (2048, 256), (1, 2048), 0), out=buf446)
        del arg152_1
        del buf441
        buf447 = buf416; del buf416  # reuse
        # Topologically Sorted Source Nodes: [setitem_33, view_50, value_states_33], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_93.run(buf444, buf421, buf447, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_33, view_50, value_states_33], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf446, buf447, 256, stream=stream0)
        del buf446
        buf449 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_94.run(buf447, buf444, buf421, buf449, 117440512, stream=stream0)
        del buf421
        buf450 = reinterpret_tensor(buf430, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf430  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_48, query_states_65, q_16, chunk_32, mul_146, neg_32, cat_32, mul_147, q_embed_16, attn_output_64], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf442, arg1_1, arg0_1, arg2_1, buf450, 2048, stream=stream0)
        buf451 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_64], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_95.run(buf449, buf451, 2048, 8192, stream=stream0)
        buf452 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_48, query_states_65, q_16, chunk_32, mul_146, neg_32, cat_32, mul_147, q_embed_16, attn_output_64], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf450, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf451, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf452)
        buf456 = reinterpret_tensor(buf452, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf452  # reuse
        # Topologically Sorted Source Nodes: [attn_output_64, arange_16, attn_mask_16], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf456, arg1_1, 16, 8192, stream=stream0)
        buf457 = reinterpret_tensor(buf451, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf451  # reuse
        # Topologically Sorted Source Nodes: [setitem_33, attn_output_64], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_96.run(buf449, buf457, 16777216, stream=stream0)
        buf458 = reinterpret_tensor(buf450, (16, 1, 128), (128, 128, 1), 0); del buf450  # reuse
        # Topologically Sorted Source Nodes: [attn_output_64, arange_16, attn_mask_16, setitem_33], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf456, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf457, (16, 8192, 128), (1048576, 128, 1), 0), out=buf458)
        del buf456
        del buf457
        buf459 = reinterpret_tensor(buf442, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf442  # reuse
        # Topologically Sorted Source Nodes: [attn_output_64], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf458, buf459, 2048, stream=stream0)
        del buf458
        buf460 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_64, transpose_67, attn_output_66, attn_output_67], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf459, (1, 2048), (0, 1), 0), reinterpret_tensor(arg153_1, (2048, 2048), (1, 2048), 0), out=buf460)
        del arg153_1
        buf462 = reinterpret_tensor(buf459, (1, 2048), (2048, 1), 0); del buf459  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_81, to_134, pow_34, variance_33, add_100, rsqrt_33, mul_150, hidden_33, hidden_states_82], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf439, buf460, arg154_1, buf462, 1, 2048, stream=stream0)
        del arg154_1
        buf463 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_116], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf462, reinterpret_tensor(arg155_1, (2048, 6144), (1, 2048), 0), out=buf463)
        del arg155_1
        buf464 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_117], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf462, reinterpret_tensor(arg156_1, (2048, 6144), (1, 2048), 0), out=buf464)
        del arg156_1
        buf465 = buf463; del buf463  # reuse
        # Topologically Sorted Source Nodes: [silu_16, mul_152], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf465, buf464, 6144, stream=stream0)
        del buf464
        buf466 = buf462; del buf462  # reuse
        # Topologically Sorted Source Nodes: [silu_16, mul_152, hidden_states_83], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf465, reinterpret_tensor(arg157_1, (6144, 2048), (1, 6144), 0), out=buf466)
        del arg157_1
        del buf465
        buf468 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_81, hidden_states_84, to_136, pow_35, variance_34, add_102, rsqrt_34, mul_153, hidden_34, hidden_states_85], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf439, buf460, buf466, arg158_1, buf468, 1, 2048, stream=stream0)
        del arg158_1
        buf469 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_68], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf468, reinterpret_tensor(arg159_1, (2048, 2048), (1, 2048), 0), out=buf469)
        del arg159_1
        buf470 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_51], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf468, reinterpret_tensor(arg160_1, (2048, 256), (1, 2048), 0), out=buf470)
        del arg160_1
        buf471 = buf447; del buf447  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_52, key_states_52, k_17, chunk_35, setitem_34, mul_157, neg_35, cat_35, mul_158, k_embed_17, key_states_53], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_97.run(buf449, buf471, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_52, key_states_52, k_17, chunk_35, setitem_34, mul_157, neg_35, cat_35, mul_158, k_embed_17, key_states_53], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf470, arg0_1, arg2_1, buf471, 256, stream=stream0)
        buf473 = buf470; del buf470  # reuse
        # Topologically Sorted Source Nodes: [value_states_34], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf468, reinterpret_tensor(arg161_1, (2048, 256), (1, 2048), 0), out=buf473)
        del arg161_1
        del buf468
        buf474 = buf444; del buf444  # reuse
        # Topologically Sorted Source Nodes: [setitem_35, view_53, value_states_35], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_98.run(buf471, buf449, buf474, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_35, view_53, value_states_35], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf473, buf474, 256, stream=stream0)
        del buf473
        buf476 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_99.run(buf474, buf471, buf449, buf476, 117440512, stream=stream0)
        del buf449
        buf477 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_51, query_states_69, q_17, chunk_34, mul_155, neg_34, cat_34, mul_156, q_embed_17, attn_output_68], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf469, arg1_1, arg0_1, arg2_1, buf477, 2048, stream=stream0)
        buf478 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_68], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_100.run(buf476, buf478, 2048, 8192, stream=stream0)
        buf479 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_51, query_states_69, q_17, chunk_34, mul_155, neg_34, cat_34, mul_156, q_embed_17, attn_output_68], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf477, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf478, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf479)
        buf483 = reinterpret_tensor(buf479, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf479  # reuse
        # Topologically Sorted Source Nodes: [attn_output_68, arange_17, attn_mask_17], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf483, arg1_1, 16, 8192, stream=stream0)
        buf484 = reinterpret_tensor(buf478, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf478  # reuse
        # Topologically Sorted Source Nodes: [setitem_35, attn_output_68], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_101.run(buf476, buf484, 16777216, stream=stream0)
        buf485 = reinterpret_tensor(buf477, (16, 1, 128), (128, 128, 1), 0); del buf477  # reuse
        # Topologically Sorted Source Nodes: [attn_output_68, arange_17, attn_mask_17, setitem_35], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf483, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf484, (16, 8192, 128), (1048576, 128, 1), 0), out=buf485)
        del buf483
        del buf484
        buf486 = reinterpret_tensor(buf469, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf469  # reuse
        # Topologically Sorted Source Nodes: [attn_output_68], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf485, buf486, 2048, stream=stream0)
        buf487 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_68, transpose_71, attn_output_70, attn_output_71], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf486, (1, 2048), (0, 1), 0), reinterpret_tensor(arg162_1, (2048, 2048), (1, 2048), 0), out=buf487)
        del arg162_1
        buf489 = reinterpret_tensor(buf486, (1, 2048), (2048, 1), 0); del buf486  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_81, hidden_states_84, hidden_states_86, to_142, pow_36, variance_35, add_106, rsqrt_35, mul_159, hidden_35, hidden_states_87], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf439, buf460, buf466, buf487, arg163_1, buf489, 1, 2048, stream=stream0)
        del arg163_1
        buf490 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_123], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf489, reinterpret_tensor(arg164_1, (2048, 6144), (1, 2048), 0), out=buf490)
        del arg164_1
        buf491 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_124], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf489, reinterpret_tensor(arg165_1, (2048, 6144), (1, 2048), 0), out=buf491)
        del arg165_1
        buf492 = buf490; del buf490  # reuse
        # Topologically Sorted Source Nodes: [silu_17, mul_161], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf492, buf491, 6144, stream=stream0)
        del buf491
        buf493 = buf489; del buf489  # reuse
        # Topologically Sorted Source Nodes: [silu_17, mul_161, hidden_states_88], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf492, reinterpret_tensor(arg166_1, (6144, 2048), (1, 6144), 0), out=buf493)
        del arg166_1
        del buf492
        buf494 = buf439; del buf439  # reuse
        buf496 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_81, hidden_states_84, hidden_states_86, hidden_states_89, to_144, pow_37, variance_36, add_108, rsqrt_36, mul_162, hidden_36, hidden_states_90], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf494, buf460, buf466, buf487, buf493, arg167_1, buf496, 1, 2048, stream=stream0)
        del arg167_1
        del buf460
        del buf466
        del buf487
        buf497 = buf493; del buf493  # reuse
        # Topologically Sorted Source Nodes: [query_states_72], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf496, reinterpret_tensor(arg168_1, (2048, 2048), (1, 2048), 0), out=buf497)
        del arg168_1
        buf498 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_54], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf496, reinterpret_tensor(arg169_1, (2048, 256), (1, 2048), 0), out=buf498)
        del arg169_1
        buf499 = buf474; del buf474  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_55, key_states_55, k_18, chunk_37, setitem_36, mul_166, neg_37, cat_37, mul_167, k_embed_18, key_states_56], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_102.run(buf476, buf499, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_55, key_states_55, k_18, chunk_37, setitem_36, mul_166, neg_37, cat_37, mul_167, k_embed_18, key_states_56], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf498, arg0_1, arg2_1, buf499, 256, stream=stream0)
        buf501 = buf498; del buf498  # reuse
        # Topologically Sorted Source Nodes: [value_states_36], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf496, reinterpret_tensor(arg170_1, (2048, 256), (1, 2048), 0), out=buf501)
        del arg170_1
        del buf496
        buf502 = buf471; del buf471  # reuse
        # Topologically Sorted Source Nodes: [setitem_37, view_56, value_states_37], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_103.run(buf499, buf476, buf502, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_37, view_56, value_states_37], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf501, buf502, 256, stream=stream0)
        del buf501
        buf504 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_104.run(buf502, buf499, buf476, buf504, 117440512, stream=stream0)
        del buf476
        buf505 = reinterpret_tensor(buf485, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf485  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_54, query_states_73, q_18, chunk_36, mul_164, neg_36, cat_36, mul_165, q_embed_18, attn_output_72], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf497, arg1_1, arg0_1, arg2_1, buf505, 2048, stream=stream0)
        buf506 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_72], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_105.run(buf504, buf506, 2048, 8192, stream=stream0)
        buf507 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_54, query_states_73, q_18, chunk_36, mul_164, neg_36, cat_36, mul_165, q_embed_18, attn_output_72], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf505, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf506, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf507)
        buf511 = reinterpret_tensor(buf507, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf507  # reuse
        # Topologically Sorted Source Nodes: [attn_output_72, arange_18, attn_mask_18], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf511, arg1_1, 16, 8192, stream=stream0)
        buf512 = reinterpret_tensor(buf506, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf506  # reuse
        # Topologically Sorted Source Nodes: [setitem_37, attn_output_72], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_106.run(buf504, buf512, 16777216, stream=stream0)
        buf513 = reinterpret_tensor(buf505, (16, 1, 128), (128, 128, 1), 0); del buf505  # reuse
        # Topologically Sorted Source Nodes: [attn_output_72, arange_18, attn_mask_18, setitem_37], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf511, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf512, (16, 8192, 128), (1048576, 128, 1), 0), out=buf513)
        del buf511
        del buf512
        buf514 = reinterpret_tensor(buf497, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf497  # reuse
        # Topologically Sorted Source Nodes: [attn_output_72], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf513, buf514, 2048, stream=stream0)
        del buf513
        buf515 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_72, transpose_75, attn_output_74, attn_output_75], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf514, (1, 2048), (0, 1), 0), reinterpret_tensor(arg171_1, (2048, 2048), (1, 2048), 0), out=buf515)
        del arg171_1
        buf517 = reinterpret_tensor(buf514, (1, 2048), (2048, 1), 0); del buf514  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_91, to_150, pow_38, variance_37, add_112, rsqrt_37, mul_168, hidden_37, hidden_states_92], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf494, buf515, arg172_1, buf517, 1, 2048, stream=stream0)
        del arg172_1
        buf518 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_130], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf517, reinterpret_tensor(arg173_1, (2048, 6144), (1, 2048), 0), out=buf518)
        del arg173_1
        buf519 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_131], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf517, reinterpret_tensor(arg174_1, (2048, 6144), (1, 2048), 0), out=buf519)
        del arg174_1
        buf520 = buf518; del buf518  # reuse
        # Topologically Sorted Source Nodes: [silu_18, mul_170], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf520, buf519, 6144, stream=stream0)
        del buf519
        buf521 = buf517; del buf517  # reuse
        # Topologically Sorted Source Nodes: [silu_18, mul_170, hidden_states_93], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf520, reinterpret_tensor(arg175_1, (6144, 2048), (1, 6144), 0), out=buf521)
        del arg175_1
        del buf520
        buf523 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_91, hidden_states_94, to_152, pow_39, variance_38, add_114, rsqrt_38, mul_171, hidden_38, hidden_states_95], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf494, buf515, buf521, arg176_1, buf523, 1, 2048, stream=stream0)
        del arg176_1
        buf524 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_76], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf523, reinterpret_tensor(arg177_1, (2048, 2048), (1, 2048), 0), out=buf524)
        del arg177_1
        buf525 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_57], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf523, reinterpret_tensor(arg178_1, (2048, 256), (1, 2048), 0), out=buf525)
        del arg178_1
        buf526 = buf502; del buf502  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_58, key_states_58, k_19, chunk_39, setitem_38, mul_175, neg_39, cat_39, mul_176, k_embed_19, key_states_59], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_107.run(buf504, buf526, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_58, key_states_58, k_19, chunk_39, setitem_38, mul_175, neg_39, cat_39, mul_176, k_embed_19, key_states_59], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf525, arg0_1, arg2_1, buf526, 256, stream=stream0)
        buf528 = buf525; del buf525  # reuse
        # Topologically Sorted Source Nodes: [value_states_38], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf523, reinterpret_tensor(arg179_1, (2048, 256), (1, 2048), 0), out=buf528)
        del arg179_1
        del buf523
        buf529 = buf499; del buf499  # reuse
        # Topologically Sorted Source Nodes: [setitem_39, view_59, value_states_39], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_108.run(buf526, buf504, buf529, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_39, view_59, value_states_39], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf528, buf529, 256, stream=stream0)
        del buf528
        buf531 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_109.run(buf529, buf526, buf504, buf531, 117440512, stream=stream0)
        del buf504
        buf532 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_57, query_states_77, q_19, chunk_38, mul_173, neg_38, cat_38, mul_174, q_embed_19, attn_output_76], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf524, arg1_1, arg0_1, arg2_1, buf532, 2048, stream=stream0)
        buf533 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_76], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_110.run(buf531, buf533, 2048, 8192, stream=stream0)
        buf534 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_57, query_states_77, q_19, chunk_38, mul_173, neg_38, cat_38, mul_174, q_embed_19, attn_output_76], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf532, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf533, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf534)
        buf538 = reinterpret_tensor(buf534, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf534  # reuse
        # Topologically Sorted Source Nodes: [attn_output_76, arange_19, attn_mask_19], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf538, arg1_1, 16, 8192, stream=stream0)
        buf539 = reinterpret_tensor(buf533, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf533  # reuse
        # Topologically Sorted Source Nodes: [setitem_39, attn_output_76], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_111.run(buf531, buf539, 16777216, stream=stream0)
        buf540 = reinterpret_tensor(buf532, (16, 1, 128), (128, 128, 1), 0); del buf532  # reuse
        # Topologically Sorted Source Nodes: [attn_output_76, arange_19, attn_mask_19, setitem_39], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf538, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf539, (16, 8192, 128), (1048576, 128, 1), 0), out=buf540)
        del buf538
        del buf539
        buf541 = reinterpret_tensor(buf524, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf524  # reuse
        # Topologically Sorted Source Nodes: [attn_output_76], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf540, buf541, 2048, stream=stream0)
        buf542 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_76, transpose_79, attn_output_78, attn_output_79], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf541, (1, 2048), (0, 1), 0), reinterpret_tensor(arg180_1, (2048, 2048), (1, 2048), 0), out=buf542)
        del arg180_1
        buf544 = reinterpret_tensor(buf541, (1, 2048), (2048, 1), 0); del buf541  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_91, hidden_states_94, hidden_states_96, to_158, pow_40, variance_39, add_118, rsqrt_39, mul_177, hidden_39, hidden_states_97], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf494, buf515, buf521, buf542, arg181_1, buf544, 1, 2048, stream=stream0)
        del arg181_1
        buf545 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_137], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf544, reinterpret_tensor(arg182_1, (2048, 6144), (1, 2048), 0), out=buf545)
        del arg182_1
        buf546 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_138], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf544, reinterpret_tensor(arg183_1, (2048, 6144), (1, 2048), 0), out=buf546)
        del arg183_1
        buf547 = buf545; del buf545  # reuse
        # Topologically Sorted Source Nodes: [silu_19, mul_179], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf547, buf546, 6144, stream=stream0)
        del buf546
        buf548 = buf544; del buf544  # reuse
        # Topologically Sorted Source Nodes: [silu_19, mul_179, hidden_states_98], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf547, reinterpret_tensor(arg184_1, (6144, 2048), (1, 6144), 0), out=buf548)
        del arg184_1
        del buf547
        buf549 = buf494; del buf494  # reuse
        buf551 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_91, hidden_states_94, hidden_states_96, hidden_states_99, to_160, pow_41, variance_40, add_120, rsqrt_40, mul_180, hidden_40, hidden_states_100], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf549, buf515, buf521, buf542, buf548, arg185_1, buf551, 1, 2048, stream=stream0)
        del arg185_1
        del buf515
        del buf521
        del buf542
        buf552 = buf548; del buf548  # reuse
        # Topologically Sorted Source Nodes: [query_states_80], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf551, reinterpret_tensor(arg186_1, (2048, 2048), (1, 2048), 0), out=buf552)
        del arg186_1
        buf553 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_60], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf551, reinterpret_tensor(arg187_1, (2048, 256), (1, 2048), 0), out=buf553)
        del arg187_1
        buf554 = buf529; del buf529  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_61, key_states_61, k_20, chunk_41, setitem_40, mul_184, neg_41, cat_41, mul_185, k_embed_20, key_states_62], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_112.run(buf531, buf554, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_61, key_states_61, k_20, chunk_41, setitem_40, mul_184, neg_41, cat_41, mul_185, k_embed_20, key_states_62], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf553, arg0_1, arg2_1, buf554, 256, stream=stream0)
        buf556 = buf553; del buf553  # reuse
        # Topologically Sorted Source Nodes: [value_states_40], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf551, reinterpret_tensor(arg188_1, (2048, 256), (1, 2048), 0), out=buf556)
        del arg188_1
        del buf551
        buf557 = buf526; del buf526  # reuse
        # Topologically Sorted Source Nodes: [setitem_41, view_62, value_states_41], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_113.run(buf554, buf531, buf557, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_41, view_62, value_states_41], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf556, buf557, 256, stream=stream0)
        del buf556
        buf559 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_114.run(buf557, buf554, buf531, buf559, 117440512, stream=stream0)
        del buf531
        buf560 = reinterpret_tensor(buf540, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf540  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_60, query_states_81, q_20, chunk_40, mul_182, neg_40, cat_40, mul_183, q_embed_20, attn_output_80], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf552, arg1_1, arg0_1, arg2_1, buf560, 2048, stream=stream0)
        buf561 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_80], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_115.run(buf559, buf561, 2048, 8192, stream=stream0)
        buf562 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_60, query_states_81, q_20, chunk_40, mul_182, neg_40, cat_40, mul_183, q_embed_20, attn_output_80], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf560, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf561, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf562)
        buf566 = reinterpret_tensor(buf562, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf562  # reuse
        # Topologically Sorted Source Nodes: [attn_output_80, arange_20, attn_mask_20], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf566, arg1_1, 16, 8192, stream=stream0)
        buf567 = reinterpret_tensor(buf561, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf561  # reuse
        # Topologically Sorted Source Nodes: [setitem_41, attn_output_80], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_116.run(buf559, buf567, 16777216, stream=stream0)
        buf568 = reinterpret_tensor(buf560, (16, 1, 128), (128, 128, 1), 0); del buf560  # reuse
        # Topologically Sorted Source Nodes: [attn_output_80, arange_20, attn_mask_20, setitem_41], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf566, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf567, (16, 8192, 128), (1048576, 128, 1), 0), out=buf568)
        del buf566
        del buf567
        buf569 = reinterpret_tensor(buf552, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf552  # reuse
        # Topologically Sorted Source Nodes: [attn_output_80], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf568, buf569, 2048, stream=stream0)
        del buf568
        buf570 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_80, transpose_83, attn_output_82, attn_output_83], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf569, (1, 2048), (0, 1), 0), reinterpret_tensor(arg189_1, (2048, 2048), (1, 2048), 0), out=buf570)
        del arg189_1
        buf572 = reinterpret_tensor(buf569, (1, 2048), (2048, 1), 0); del buf569  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_101, to_166, pow_42, variance_41, add_124, rsqrt_41, mul_186, hidden_41, hidden_states_102], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf549, buf570, arg190_1, buf572, 1, 2048, stream=stream0)
        del arg190_1
        buf573 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_144], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf572, reinterpret_tensor(arg191_1, (2048, 6144), (1, 2048), 0), out=buf573)
        del arg191_1
        buf574 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_145], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf572, reinterpret_tensor(arg192_1, (2048, 6144), (1, 2048), 0), out=buf574)
        del arg192_1
        buf575 = buf573; del buf573  # reuse
        # Topologically Sorted Source Nodes: [silu_20, mul_188], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf575, buf574, 6144, stream=stream0)
        del buf574
        buf576 = buf572; del buf572  # reuse
        # Topologically Sorted Source Nodes: [silu_20, mul_188, hidden_states_103], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf575, reinterpret_tensor(arg193_1, (6144, 2048), (1, 6144), 0), out=buf576)
        del arg193_1
        del buf575
        buf578 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_101, hidden_states_104, to_168, pow_43, variance_42, add_126, rsqrt_42, mul_189, hidden_42, hidden_states_105], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf549, buf570, buf576, arg194_1, buf578, 1, 2048, stream=stream0)
        del arg194_1
        buf579 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_84], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf578, reinterpret_tensor(arg195_1, (2048, 2048), (1, 2048), 0), out=buf579)
        del arg195_1
        buf580 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_63], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf578, reinterpret_tensor(arg196_1, (2048, 256), (1, 2048), 0), out=buf580)
        del arg196_1
        buf581 = buf557; del buf557  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_64, key_states_64, k_21, chunk_43, setitem_42, mul_193, neg_43, cat_43, mul_194, k_embed_21, key_states_65], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_117.run(buf559, buf581, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_64, key_states_64, k_21, chunk_43, setitem_42, mul_193, neg_43, cat_43, mul_194, k_embed_21, key_states_65], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf580, arg0_1, arg2_1, buf581, 256, stream=stream0)
        buf583 = buf580; del buf580  # reuse
        # Topologically Sorted Source Nodes: [value_states_42], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf578, reinterpret_tensor(arg197_1, (2048, 256), (1, 2048), 0), out=buf583)
        del arg197_1
        del buf578
        buf584 = buf554; del buf554  # reuse
        # Topologically Sorted Source Nodes: [setitem_43, view_65, value_states_43], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_118.run(buf581, buf559, buf584, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_43, view_65, value_states_43], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf583, buf584, 256, stream=stream0)
        del buf583
        buf586 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_119.run(buf584, buf581, buf559, buf586, 117440512, stream=stream0)
        del buf559
        buf587 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_63, query_states_85, q_21, chunk_42, mul_191, neg_42, cat_42, mul_192, q_embed_21, attn_output_84], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf579, arg1_1, arg0_1, arg2_1, buf587, 2048, stream=stream0)
        buf588 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_84], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_120.run(buf586, buf588, 2048, 8192, stream=stream0)
        buf589 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_63, query_states_85, q_21, chunk_42, mul_191, neg_42, cat_42, mul_192, q_embed_21, attn_output_84], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf587, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf588, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf589)
        buf593 = reinterpret_tensor(buf589, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf589  # reuse
        # Topologically Sorted Source Nodes: [attn_output_84, arange_21, attn_mask_21], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf593, arg1_1, 16, 8192, stream=stream0)
        buf594 = reinterpret_tensor(buf588, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf588  # reuse
        # Topologically Sorted Source Nodes: [setitem_43, attn_output_84], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_121.run(buf586, buf594, 16777216, stream=stream0)
        buf595 = reinterpret_tensor(buf587, (16, 1, 128), (128, 128, 1), 0); del buf587  # reuse
        # Topologically Sorted Source Nodes: [attn_output_84, arange_21, attn_mask_21, setitem_43], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf593, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf594, (16, 8192, 128), (1048576, 128, 1), 0), out=buf595)
        del buf593
        del buf594
        buf596 = reinterpret_tensor(buf579, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf579  # reuse
        # Topologically Sorted Source Nodes: [attn_output_84], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf595, buf596, 2048, stream=stream0)
        buf597 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_84, transpose_87, attn_output_86, attn_output_87], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf596, (1, 2048), (0, 1), 0), reinterpret_tensor(arg198_1, (2048, 2048), (1, 2048), 0), out=buf597)
        del arg198_1
        buf599 = reinterpret_tensor(buf596, (1, 2048), (2048, 1), 0); del buf596  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_101, hidden_states_104, hidden_states_106, to_174, pow_44, variance_43, add_130, rsqrt_43, mul_195, hidden_43, hidden_states_107], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf549, buf570, buf576, buf597, arg199_1, buf599, 1, 2048, stream=stream0)
        del arg199_1
        buf600 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_151], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf599, reinterpret_tensor(arg200_1, (2048, 6144), (1, 2048), 0), out=buf600)
        del arg200_1
        buf601 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_152], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf599, reinterpret_tensor(arg201_1, (2048, 6144), (1, 2048), 0), out=buf601)
        del arg201_1
        buf602 = buf600; del buf600  # reuse
        # Topologically Sorted Source Nodes: [silu_21, mul_197], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf602, buf601, 6144, stream=stream0)
        del buf601
        buf603 = buf599; del buf599  # reuse
        # Topologically Sorted Source Nodes: [silu_21, mul_197, hidden_states_108], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf602, reinterpret_tensor(arg202_1, (6144, 2048), (1, 6144), 0), out=buf603)
        del arg202_1
        del buf602
        buf604 = buf549; del buf549  # reuse
        buf606 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_101, hidden_states_104, hidden_states_106, hidden_states_109, to_176, pow_45, variance_44, add_132, rsqrt_44, mul_198, hidden_44, hidden_states_110], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf604, buf570, buf576, buf597, buf603, arg203_1, buf606, 1, 2048, stream=stream0)
        del arg203_1
        del buf570
        del buf576
        del buf597
        buf607 = buf603; del buf603  # reuse
        # Topologically Sorted Source Nodes: [query_states_88], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf606, reinterpret_tensor(arg204_1, (2048, 2048), (1, 2048), 0), out=buf607)
        del arg204_1
        buf608 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_66], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf606, reinterpret_tensor(arg205_1, (2048, 256), (1, 2048), 0), out=buf608)
        del arg205_1
        buf609 = buf584; del buf584  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_67, key_states_67, k_22, chunk_45, setitem_44, mul_202, neg_45, cat_45, mul_203, k_embed_22, key_states_68], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_122.run(buf586, buf609, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_67, key_states_67, k_22, chunk_45, setitem_44, mul_202, neg_45, cat_45, mul_203, k_embed_22, key_states_68], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf608, arg0_1, arg2_1, buf609, 256, stream=stream0)
        buf611 = buf608; del buf608  # reuse
        # Topologically Sorted Source Nodes: [value_states_44], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf606, reinterpret_tensor(arg206_1, (2048, 256), (1, 2048), 0), out=buf611)
        del arg206_1
        del buf606
        buf612 = buf581; del buf581  # reuse
        # Topologically Sorted Source Nodes: [setitem_45, view_68, value_states_45], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_123.run(buf609, buf586, buf612, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_45, view_68, value_states_45], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf611, buf612, 256, stream=stream0)
        del buf611
        buf614 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_124.run(buf612, buf609, buf586, buf614, 117440512, stream=stream0)
        del buf586
        buf615 = reinterpret_tensor(buf595, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf595  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_66, query_states_89, q_22, chunk_44, mul_200, neg_44, cat_44, mul_201, q_embed_22, attn_output_88], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf607, arg1_1, arg0_1, arg2_1, buf615, 2048, stream=stream0)
        buf616 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_88], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_125.run(buf614, buf616, 2048, 8192, stream=stream0)
        buf617 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_66, query_states_89, q_22, chunk_44, mul_200, neg_44, cat_44, mul_201, q_embed_22, attn_output_88], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf615, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf616, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf617)
        buf621 = reinterpret_tensor(buf617, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf617  # reuse
        # Topologically Sorted Source Nodes: [attn_output_88, arange_22, attn_mask_22], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf621, arg1_1, 16, 8192, stream=stream0)
        buf622 = reinterpret_tensor(buf616, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf616  # reuse
        # Topologically Sorted Source Nodes: [setitem_45, attn_output_88], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_126.run(buf614, buf622, 16777216, stream=stream0)
        buf623 = reinterpret_tensor(buf615, (16, 1, 128), (128, 128, 1), 0); del buf615  # reuse
        # Topologically Sorted Source Nodes: [attn_output_88, arange_22, attn_mask_22, setitem_45], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf621, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf622, (16, 8192, 128), (1048576, 128, 1), 0), out=buf623)
        del buf621
        del buf622
        buf624 = reinterpret_tensor(buf607, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf607  # reuse
        # Topologically Sorted Source Nodes: [attn_output_88], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf623, buf624, 2048, stream=stream0)
        del buf623
        buf625 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_88, transpose_91, attn_output_90, attn_output_91], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf624, (1, 2048), (0, 1), 0), reinterpret_tensor(arg207_1, (2048, 2048), (1, 2048), 0), out=buf625)
        del arg207_1
        buf627 = reinterpret_tensor(buf624, (1, 2048), (2048, 1), 0); del buf624  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_111, to_182, pow_46, variance_45, add_136, rsqrt_45, mul_204, hidden_45, hidden_states_112], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf604, buf625, arg208_1, buf627, 1, 2048, stream=stream0)
        del arg208_1
        buf628 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_158], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf627, reinterpret_tensor(arg209_1, (2048, 6144), (1, 2048), 0), out=buf628)
        del arg209_1
        buf629 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_159], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf627, reinterpret_tensor(arg210_1, (2048, 6144), (1, 2048), 0), out=buf629)
        del arg210_1
        buf630 = buf628; del buf628  # reuse
        # Topologically Sorted Source Nodes: [silu_22, mul_206], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf630, buf629, 6144, stream=stream0)
        del buf629
        buf631 = buf627; del buf627  # reuse
        # Topologically Sorted Source Nodes: [silu_22, mul_206, hidden_states_113], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf630, reinterpret_tensor(arg211_1, (6144, 2048), (1, 6144), 0), out=buf631)
        del arg211_1
        del buf630
        buf633 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_111, hidden_states_114, to_184, pow_47, variance_46, add_138, rsqrt_46, mul_207, hidden_46, hidden_states_115], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf604, buf625, buf631, arg212_1, buf633, 1, 2048, stream=stream0)
        del arg212_1
        buf634 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_92], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf633, reinterpret_tensor(arg213_1, (2048, 2048), (1, 2048), 0), out=buf634)
        del arg213_1
        buf635 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_69], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf633, reinterpret_tensor(arg214_1, (2048, 256), (1, 2048), 0), out=buf635)
        del arg214_1
        buf636 = buf612; del buf612  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_70, key_states_70, k_23, chunk_47, setitem_46, mul_211, neg_47, cat_47, mul_212, k_embed_23, key_states_71], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_127.run(buf614, buf636, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_70, key_states_70, k_23, chunk_47, setitem_46, mul_211, neg_47, cat_47, mul_212, k_embed_23, key_states_71], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf635, arg0_1, arg2_1, buf636, 256, stream=stream0)
        buf638 = buf635; del buf635  # reuse
        # Topologically Sorted Source Nodes: [value_states_46], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf633, reinterpret_tensor(arg215_1, (2048, 256), (1, 2048), 0), out=buf638)
        del arg215_1
        del buf633
        buf639 = buf609; del buf609  # reuse
        # Topologically Sorted Source Nodes: [setitem_47, view_71, value_states_47], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_128.run(buf636, buf614, buf639, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_47, view_71, value_states_47], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf638, buf639, 256, stream=stream0)
        del buf638
        buf641 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_129.run(buf639, buf636, buf614, buf641, 117440512, stream=stream0)
        del buf614
        buf642 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_69, query_states_93, q_23, chunk_46, mul_209, neg_46, cat_46, mul_210, q_embed_23, attn_output_92], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf634, arg1_1, arg0_1, arg2_1, buf642, 2048, stream=stream0)
        buf643 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_92], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_130.run(buf641, buf643, 2048, 8192, stream=stream0)
        buf644 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_69, query_states_93, q_23, chunk_46, mul_209, neg_46, cat_46, mul_210, q_embed_23, attn_output_92], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf642, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf643, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf644)
        buf648 = reinterpret_tensor(buf644, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf644  # reuse
        # Topologically Sorted Source Nodes: [attn_output_92, arange_23, attn_mask_23], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf648, arg1_1, 16, 8192, stream=stream0)
        buf649 = reinterpret_tensor(buf643, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf643  # reuse
        # Topologically Sorted Source Nodes: [setitem_47, attn_output_92], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_131.run(buf641, buf649, 16777216, stream=stream0)
        buf650 = reinterpret_tensor(buf642, (16, 1, 128), (128, 128, 1), 0); del buf642  # reuse
        # Topologically Sorted Source Nodes: [attn_output_92, arange_23, attn_mask_23, setitem_47], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf648, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf649, (16, 8192, 128), (1048576, 128, 1), 0), out=buf650)
        del buf648
        del buf649
        buf651 = reinterpret_tensor(buf634, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf634  # reuse
        # Topologically Sorted Source Nodes: [attn_output_92], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf650, buf651, 2048, stream=stream0)
        buf652 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_92, transpose_95, attn_output_94, attn_output_95], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf651, (1, 2048), (0, 1), 0), reinterpret_tensor(arg216_1, (2048, 2048), (1, 2048), 0), out=buf652)
        del arg216_1
        buf654 = reinterpret_tensor(buf651, (1, 2048), (2048, 1), 0); del buf651  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_111, hidden_states_114, hidden_states_116, to_190, pow_48, variance_47, add_142, rsqrt_47, mul_213, hidden_47, hidden_states_117], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf604, buf625, buf631, buf652, arg217_1, buf654, 1, 2048, stream=stream0)
        del arg217_1
        buf655 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_165], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf654, reinterpret_tensor(arg218_1, (2048, 6144), (1, 2048), 0), out=buf655)
        del arg218_1
        buf656 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_166], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf654, reinterpret_tensor(arg219_1, (2048, 6144), (1, 2048), 0), out=buf656)
        del arg219_1
        buf657 = buf655; del buf655  # reuse
        # Topologically Sorted Source Nodes: [silu_23, mul_215], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf657, buf656, 6144, stream=stream0)
        del buf656
        buf658 = buf654; del buf654  # reuse
        # Topologically Sorted Source Nodes: [silu_23, mul_215, hidden_states_118], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf657, reinterpret_tensor(arg220_1, (6144, 2048), (1, 6144), 0), out=buf658)
        del arg220_1
        del buf657
        buf659 = buf604; del buf604  # reuse
        buf661 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_111, hidden_states_114, hidden_states_116, hidden_states_119, to_192, pow_49, variance_48, add_144, rsqrt_48, mul_216, hidden_48, hidden_states_120], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf659, buf625, buf631, buf652, buf658, arg221_1, buf661, 1, 2048, stream=stream0)
        del arg221_1
        del buf625
        del buf631
        del buf652
        buf662 = buf658; del buf658  # reuse
        # Topologically Sorted Source Nodes: [query_states_96], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf661, reinterpret_tensor(arg222_1, (2048, 2048), (1, 2048), 0), out=buf662)
        del arg222_1
        buf663 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_72], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf661, reinterpret_tensor(arg223_1, (2048, 256), (1, 2048), 0), out=buf663)
        del arg223_1
        buf664 = buf639; del buf639  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_73, key_states_73, k_24, chunk_49, setitem_48, mul_220, neg_49, cat_49, mul_221, k_embed_24, key_states_74], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_132.run(buf641, buf664, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_73, key_states_73, k_24, chunk_49, setitem_48, mul_220, neg_49, cat_49, mul_221, k_embed_24, key_states_74], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf663, arg0_1, arg2_1, buf664, 256, stream=stream0)
        buf666 = buf663; del buf663  # reuse
        # Topologically Sorted Source Nodes: [value_states_48], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf661, reinterpret_tensor(arg224_1, (2048, 256), (1, 2048), 0), out=buf666)
        del arg224_1
        del buf661
        buf667 = buf636; del buf636  # reuse
        # Topologically Sorted Source Nodes: [setitem_49, view_74, value_states_49], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_133.run(buf664, buf641, buf667, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_49, view_74, value_states_49], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf666, buf667, 256, stream=stream0)
        del buf666
        buf669 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_134.run(buf667, buf664, buf641, buf669, 117440512, stream=stream0)
        del buf641
        buf670 = reinterpret_tensor(buf650, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf650  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_72, query_states_97, q_24, chunk_48, mul_218, neg_48, cat_48, mul_219, q_embed_24, attn_output_96], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf662, arg1_1, arg0_1, arg2_1, buf670, 2048, stream=stream0)
        buf671 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_96], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_135.run(buf669, buf671, 2048, 8192, stream=stream0)
        buf672 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_72, query_states_97, q_24, chunk_48, mul_218, neg_48, cat_48, mul_219, q_embed_24, attn_output_96], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf670, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf671, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf672)
        buf676 = reinterpret_tensor(buf672, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf672  # reuse
        # Topologically Sorted Source Nodes: [attn_output_96, arange_24, attn_mask_24], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf676, arg1_1, 16, 8192, stream=stream0)
        buf677 = reinterpret_tensor(buf671, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf671  # reuse
        # Topologically Sorted Source Nodes: [setitem_49, attn_output_96], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_136.run(buf669, buf677, 16777216, stream=stream0)
        buf678 = reinterpret_tensor(buf670, (16, 1, 128), (128, 128, 1), 0); del buf670  # reuse
        # Topologically Sorted Source Nodes: [attn_output_96, arange_24, attn_mask_24, setitem_49], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf676, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf677, (16, 8192, 128), (1048576, 128, 1), 0), out=buf678)
        del buf676
        del buf677
        buf679 = reinterpret_tensor(buf662, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf662  # reuse
        # Topologically Sorted Source Nodes: [attn_output_96], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf678, buf679, 2048, stream=stream0)
        del buf678
        buf680 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_96, transpose_99, attn_output_98, attn_output_99], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf679, (1, 2048), (0, 1), 0), reinterpret_tensor(arg225_1, (2048, 2048), (1, 2048), 0), out=buf680)
        del arg225_1
        buf682 = reinterpret_tensor(buf679, (1, 2048), (2048, 1), 0); del buf679  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_121, to_198, pow_50, variance_49, add_148, rsqrt_49, mul_222, hidden_49, hidden_states_122], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf659, buf680, arg226_1, buf682, 1, 2048, stream=stream0)
        del arg226_1
        buf683 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_172], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf682, reinterpret_tensor(arg227_1, (2048, 6144), (1, 2048), 0), out=buf683)
        del arg227_1
        buf684 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_173], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf682, reinterpret_tensor(arg228_1, (2048, 6144), (1, 2048), 0), out=buf684)
        del arg228_1
        buf685 = buf683; del buf683  # reuse
        # Topologically Sorted Source Nodes: [silu_24, mul_224], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf685, buf684, 6144, stream=stream0)
        del buf684
        buf686 = buf682; del buf682  # reuse
        # Topologically Sorted Source Nodes: [silu_24, mul_224, hidden_states_123], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf685, reinterpret_tensor(arg229_1, (6144, 2048), (1, 6144), 0), out=buf686)
        del arg229_1
        del buf685
        buf688 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_121, hidden_states_124, to_200, pow_51, variance_50, add_150, rsqrt_50, mul_225, hidden_50, hidden_states_125], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf659, buf680, buf686, arg230_1, buf688, 1, 2048, stream=stream0)
        del arg230_1
        buf689 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_100], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf688, reinterpret_tensor(arg231_1, (2048, 2048), (1, 2048), 0), out=buf689)
        del arg231_1
        buf690 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_75], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf688, reinterpret_tensor(arg232_1, (2048, 256), (1, 2048), 0), out=buf690)
        del arg232_1
        buf691 = buf667; del buf667  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_76, key_states_76, k_25, chunk_51, setitem_50, mul_229, neg_51, cat_51, mul_230, k_embed_25, key_states_77], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_137.run(buf669, buf691, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_76, key_states_76, k_25, chunk_51, setitem_50, mul_229, neg_51, cat_51, mul_230, k_embed_25, key_states_77], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf690, arg0_1, arg2_1, buf691, 256, stream=stream0)
        buf693 = buf690; del buf690  # reuse
        # Topologically Sorted Source Nodes: [value_states_50], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf688, reinterpret_tensor(arg233_1, (2048, 256), (1, 2048), 0), out=buf693)
        del arg233_1
        del buf688
        buf694 = buf664; del buf664  # reuse
        # Topologically Sorted Source Nodes: [setitem_51, view_77, value_states_51], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_138.run(buf691, buf669, buf694, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_51, view_77, value_states_51], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf693, buf694, 256, stream=stream0)
        del buf693
        buf696 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_139.run(buf694, buf691, buf669, buf696, 117440512, stream=stream0)
        del buf669
        buf697 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_75, query_states_101, q_25, chunk_50, mul_227, neg_50, cat_50, mul_228, q_embed_25, attn_output_100], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf689, arg1_1, arg0_1, arg2_1, buf697, 2048, stream=stream0)
        buf698 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_100], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_140.run(buf696, buf698, 2048, 8192, stream=stream0)
        buf699 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_75, query_states_101, q_25, chunk_50, mul_227, neg_50, cat_50, mul_228, q_embed_25, attn_output_100], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf697, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf698, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf699)
        buf703 = reinterpret_tensor(buf699, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf699  # reuse
        # Topologically Sorted Source Nodes: [attn_output_100, arange_25, attn_mask_25], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf703, arg1_1, 16, 8192, stream=stream0)
        buf704 = reinterpret_tensor(buf698, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf698  # reuse
        # Topologically Sorted Source Nodes: [setitem_51, attn_output_100], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_141.run(buf696, buf704, 16777216, stream=stream0)
        buf705 = reinterpret_tensor(buf697, (16, 1, 128), (128, 128, 1), 0); del buf697  # reuse
        # Topologically Sorted Source Nodes: [attn_output_100, arange_25, attn_mask_25, setitem_51], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf703, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf704, (16, 8192, 128), (1048576, 128, 1), 0), out=buf705)
        del buf703
        del buf704
        buf706 = reinterpret_tensor(buf689, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf689  # reuse
        # Topologically Sorted Source Nodes: [attn_output_100], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf705, buf706, 2048, stream=stream0)
        buf707 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_100, transpose_103, attn_output_102, attn_output_103], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf706, (1, 2048), (0, 1), 0), reinterpret_tensor(arg234_1, (2048, 2048), (1, 2048), 0), out=buf707)
        del arg234_1
        buf709 = reinterpret_tensor(buf706, (1, 2048), (2048, 1), 0); del buf706  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_121, hidden_states_124, hidden_states_126, to_206, pow_52, variance_51, add_154, rsqrt_51, mul_231, hidden_51, hidden_states_127], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf659, buf680, buf686, buf707, arg235_1, buf709, 1, 2048, stream=stream0)
        del arg235_1
        buf710 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_179], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf709, reinterpret_tensor(arg236_1, (2048, 6144), (1, 2048), 0), out=buf710)
        del arg236_1
        buf711 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_180], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf709, reinterpret_tensor(arg237_1, (2048, 6144), (1, 2048), 0), out=buf711)
        del arg237_1
        buf712 = buf710; del buf710  # reuse
        # Topologically Sorted Source Nodes: [silu_25, mul_233], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf712, buf711, 6144, stream=stream0)
        del buf711
        buf713 = buf709; del buf709  # reuse
        # Topologically Sorted Source Nodes: [silu_25, mul_233, hidden_states_128], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf712, reinterpret_tensor(arg238_1, (6144, 2048), (1, 6144), 0), out=buf713)
        del arg238_1
        del buf712
        buf714 = buf659; del buf659  # reuse
        buf716 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_121, hidden_states_124, hidden_states_126, hidden_states_129, to_208, pow_53, variance_52, add_156, rsqrt_52, mul_234, hidden_52, hidden_states_130], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_31.run(buf714, buf680, buf686, buf707, buf713, arg239_1, buf716, 1, 2048, stream=stream0)
        del arg239_1
        del buf680
        del buf686
        del buf707
        buf717 = buf713; del buf713  # reuse
        # Topologically Sorted Source Nodes: [query_states_104], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf716, reinterpret_tensor(arg240_1, (2048, 2048), (1, 2048), 0), out=buf717)
        del arg240_1
        buf718 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_78], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf716, reinterpret_tensor(arg241_1, (2048, 256), (1, 2048), 0), out=buf718)
        del arg241_1
        buf719 = buf694; del buf694  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_79, key_states_79, k_26, chunk_53, setitem_52, mul_238, neg_53, cat_53, mul_239, k_embed_26, key_states_80], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_142.run(buf696, buf719, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_79, key_states_79, k_26, chunk_53, setitem_52, mul_238, neg_53, cat_53, mul_239, k_embed_26, key_states_80], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf718, arg0_1, arg2_1, buf719, 256, stream=stream0)
        buf721 = buf718; del buf718  # reuse
        # Topologically Sorted Source Nodes: [value_states_52], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf716, reinterpret_tensor(arg242_1, (2048, 256), (1, 2048), 0), out=buf721)
        del arg242_1
        del buf716
        buf722 = buf691; del buf691  # reuse
        # Topologically Sorted Source Nodes: [setitem_53, view_80, value_states_53], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_143.run(buf719, buf696, buf722, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_53, view_80, value_states_53], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf721, buf722, 256, stream=stream0)
        del buf721
        buf724 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_144.run(buf722, buf719, buf696, buf724, 117440512, stream=stream0)
        del buf696
        buf725 = reinterpret_tensor(buf705, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf705  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_78, query_states_105, q_26, chunk_52, mul_236, neg_52, cat_52, mul_237, q_embed_26, attn_output_104], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf717, arg1_1, arg0_1, arg2_1, buf725, 2048, stream=stream0)
        buf726 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_104], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_145.run(buf724, buf726, 2048, 8192, stream=stream0)
        buf727 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_78, query_states_105, q_26, chunk_52, mul_236, neg_52, cat_52, mul_237, q_embed_26, attn_output_104], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf725, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf726, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf727)
        buf731 = reinterpret_tensor(buf727, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf727  # reuse
        # Topologically Sorted Source Nodes: [attn_output_104, arange_26, attn_mask_26], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf731, arg1_1, 16, 8192, stream=stream0)
        buf732 = reinterpret_tensor(buf726, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf726  # reuse
        # Topologically Sorted Source Nodes: [setitem_53, attn_output_104], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_146.run(buf724, buf732, 16777216, stream=stream0)
        buf733 = reinterpret_tensor(buf725, (16, 1, 128), (128, 128, 1), 0); del buf725  # reuse
        # Topologically Sorted Source Nodes: [attn_output_104, arange_26, attn_mask_26, setitem_53], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf731, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf732, (16, 8192, 128), (1048576, 128, 1), 0), out=buf733)
        del buf731
        del buf732
        buf734 = reinterpret_tensor(buf717, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf717  # reuse
        # Topologically Sorted Source Nodes: [attn_output_104], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf733, buf734, 2048, stream=stream0)
        del buf733
        buf735 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_104, transpose_107, attn_output_106, attn_output_107], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf734, (1, 2048), (0, 1), 0), reinterpret_tensor(arg243_1, (2048, 2048), (1, 2048), 0), out=buf735)
        del arg243_1
        buf737 = reinterpret_tensor(buf734, (1, 2048), (2048, 1), 0); del buf734  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_131, to_214, pow_54, variance_53, add_160, rsqrt_53, mul_240, hidden_53, hidden_states_132], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11.run(buf714, buf735, arg244_1, buf737, 1, 2048, stream=stream0)
        del arg244_1
        buf738 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_186], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf737, reinterpret_tensor(arg245_1, (2048, 6144), (1, 2048), 0), out=buf738)
        del arg245_1
        buf739 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_187], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf737, reinterpret_tensor(arg246_1, (2048, 6144), (1, 2048), 0), out=buf739)
        del arg246_1
        buf740 = buf738; del buf738  # reuse
        # Topologically Sorted Source Nodes: [silu_26, mul_242], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf740, buf739, 6144, stream=stream0)
        del buf739
        buf741 = buf737; del buf737  # reuse
        # Topologically Sorted Source Nodes: [silu_26, mul_242, hidden_states_133], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf740, reinterpret_tensor(arg247_1, (6144, 2048), (1, 6144), 0), out=buf741)
        del arg247_1
        del buf740
        buf743 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_131, hidden_states_134, to_216, pow_55, variance_54, add_162, rsqrt_54, mul_243, hidden_54, hidden_states_135], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_13.run(buf714, buf735, buf741, arg248_1, buf743, 1, 2048, stream=stream0)
        del arg248_1
        buf744 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_108], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf743, reinterpret_tensor(arg249_1, (2048, 2048), (1, 2048), 0), out=buf744)
        del arg249_1
        buf745 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_81], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf743, reinterpret_tensor(arg250_1, (2048, 256), (1, 2048), 0), out=buf745)
        del arg250_1
        buf746 = buf722; del buf722  # reuse
        # Topologically Sorted Source Nodes: [cos, sin, view_82, key_states_82, k_27, chunk_55, setitem_54, mul_247, neg_55, cat_55, mul_248, k_embed_27, key_states_83], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_147.run(buf724, buf746, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [cos, sin, view_82, key_states_82, k_27, chunk_55, setitem_54, mul_247, neg_55, cat_55, mul_248, k_embed_27, key_states_83], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.select, aten.mul, aten.neg, aten.cat, aten.add, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2.run(arg1_1, buf745, arg0_1, arg2_1, buf746, 256, stream=stream0)
        buf748 = buf745; del buf745  # reuse
        # Topologically Sorted Source Nodes: [value_states_54], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf743, reinterpret_tensor(arg251_1, (2048, 256), (1, 2048), 0), out=buf748)
        del arg251_1
        del buf743
        buf749 = buf719; del buf719  # reuse
        # Topologically Sorted Source Nodes: [setitem_55, view_83, value_states_55], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_148.run(buf746, buf724, buf749, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_55, view_83, value_states_55], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_4.run(arg1_1, buf748, buf749, 256, stream=stream0)
        del buf748
        buf751 = empty_strided_cuda((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 117440512, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.copy_]
        stream0 = get_raw_stream(0)
        triton_poi_fused_copy__149.run(buf749, buf746, buf724, buf751, arg3_1, 117440512, stream=stream0)
        del arg3_1
        del buf724
        del buf746
        del buf749
        buf752 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_81, query_states_109, q_27, chunk_54, mul_245, neg_54, cat_54, mul_246, q_embed_27, attn_output_108], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_cat_index_mul_neg_split_transpose_view_6.run(buf744, arg1_1, arg0_1, arg2_1, buf752, 2048, stream=stream0)
        del arg0_1
        del arg2_1
        buf753 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_108], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_150.run(buf751, buf753, 2048, 8192, stream=stream0)
        buf754 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [cos, sin, view_81, query_states_109, q_27, chunk_54, mul_245, neg_54, cat_54, mul_246, q_embed_27, attn_output_108], Original ATen: [aten.index, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf752, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf753, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf754)
        buf758 = reinterpret_tensor(buf754, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf754  # reuse
        # Topologically Sorted Source Nodes: [attn_output_108, arange_27, attn_mask_27], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_8.run(buf758, arg1_1, 16, 8192, stream=stream0)
        del arg1_1
        buf759 = reinterpret_tensor(buf753, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf753  # reuse
        # Topologically Sorted Source Nodes: [setitem_55, attn_output_108], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_151.run(buf751, buf759, 16777216, stream=stream0)
        del buf751
        buf760 = reinterpret_tensor(buf752, (16, 1, 128), (128, 128, 1), 0); del buf752  # reuse
        # Topologically Sorted Source Nodes: [attn_output_108, arange_27, attn_mask_27, setitem_55], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf758, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf759, (16, 8192, 128), (1048576, 128, 1), 0), out=buf760)
        del buf758
        del buf759
        buf761 = reinterpret_tensor(buf744, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf744  # reuse
        # Topologically Sorted Source Nodes: [attn_output_108], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_10.run(buf760, buf761, 2048, stream=stream0)
        del buf760
        buf762 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_108, transpose_111, attn_output_110, attn_output_111], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf761, (1, 2048), (0, 1), 0), reinterpret_tensor(arg252_1, (2048, 2048), (1, 2048), 0), out=buf762)
        del arg252_1
        buf764 = reinterpret_tensor(buf761, (1, 2048), (2048, 1), 0); del buf761  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_131, hidden_states_134, hidden_states_136, to_222, pow_56, variance_55, add_166, rsqrt_55, mul_249, hidden_55, hidden_states_137], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf714, buf735, buf741, buf762, arg253_1, buf764, 1, 2048, stream=stream0)
        del arg253_1
        buf765 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_193], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf764, reinterpret_tensor(arg254_1, (2048, 6144), (1, 2048), 0), out=buf765)
        del arg254_1
        buf766 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_194], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf764, reinterpret_tensor(arg255_1, (2048, 6144), (1, 2048), 0), out=buf766)
        del arg255_1
        buf767 = buf765; del buf765  # reuse
        # Topologically Sorted Source Nodes: [silu_27, mul_251], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_12.run(buf767, buf766, 6144, stream=stream0)
        del buf766
        buf768 = buf764; del buf764  # reuse
        # Topologically Sorted Source Nodes: [silu_27, mul_251, hidden_states_138], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf767, reinterpret_tensor(arg256_1, (6144, 2048), (1, 6144), 0), out=buf768)
        del arg256_1
        del buf767
        buf769 = buf714; del buf714  # reuse
        buf771 = buf769; del buf769  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_131, hidden_states_134, hidden_states_136, hidden_states_139, to_224, pow_57, variance_56, add_168, rsqrt_56, mul_252, hidden_56, hidden_states_140], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_152.run(buf771, buf735, buf741, buf762, buf768, arg257_1, 1, 2048, stream=stream0)
        del arg257_1
        del buf735
        del buf741
        del buf762
        del buf768
    return (buf771, )


# Finish asynchronous kernel compilation before anything below runs.
# NOTE(review): `wait` receives this module's globals; presumably it blocks
# until every kernel queued through `async_compile` is built and rebinds the
# corresponding module-level names — confirm against AsyncCompile.wait.
async_compile.wait(globals())
# The compile handle is not referenced again below this point in the visible
# code, so the module-level name is dropped.
del async_compile

class Runner:
    """Dispatch the flat graph-input list to the compiled partition callable(s).

    `partitions` is a list of callables. `call` only ever invokes
    `partitions[0]`, passing it a single list of inputs and expecting a
    one-element iterable back.
    """

    # Number of graph inputs `call` expects (arg0_1 .. arg257_1).
    _NUM_INPUTS = 258

    # Position i of the partition's argument list takes the graph input with
    # this index. Only the first eight inputs are permuted; inputs 8..257 are
    # forwarded in their original order.
    _PARTITION0_ARG_ORDER = (5, 4, 6, 7, 3, 1, 0, 2, *range(8, _NUM_INPUTS))

    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition with the result of applying its paired fn.

        `fns` is zipped with the current partitions, so the new partition list
        is as long as the shorter of the two. `self.partitions` is reassigned
        only after every `fn` has returned.
        """
        self.partitions = [
            fn(partition) for fn, partition in zip(fns, self.partitions)
        ]

    def call(self, args):
        """Run partition 0 on the graph inputs and return `(buf771,)`.

        Args:
            args: iterable of exactly 258 graph inputs that also supports
                `.clear()` (a list in practice). It is emptied in place once
                its contents have been captured.

        Returns:
            A one-element tuple holding the single output of partition 0.

        Raises:
            ValueError: if `args` does not contain exactly 258 items. As with
                the tuple unpacking this replaces, the check runs before
                `args` is cleared.
        """
        inputs = list(args)
        if len(inputs) != self._NUM_INPUTS:
            raise ValueError(
                f"expected {self._NUM_INPUTS} inputs, got {len(inputs)}"
            )
        # Drop the caller's references to the inputs.
        args.clear()
        partition0_args = [inputs[i] for i in self._PARTITION0_ARG_ORDER]
        # From here on `partition0_args` is the only reference this frame
        # holds to the inputs, mirroring the `del arg*_1` of the generated code.
        del inputs
        (buf771,) = self.partitions[0](partition0_args)
        del partition0_args
        return (buf771,)

# Module-level entry points: a single Runner wrapping the one compiled
# partition, with its bound methods re-exported under the plain names `call`
# and `recursively_apply_fns`.
runner = Runner(partitions=[partition_0])
call, recursively_apply_fns = runner.call, runner.recursively_apply_fns


def get_args():
    """Create the random example inputs ``arg0_1`` .. ``arg257_1``.

    Every tensor is produced by ``rand_strided`` on ``'cuda:0'``. All of them
    are ``torch.bfloat16`` except the second one (``arg1_1``), which is
    ``torch.int64``.

    Returns:
        A list of 258 tensors, ordered ``arg0_1`` first and ``arg257_1`` last.
    """
    from torch._dynamo.testing import rand_strided

    def bf16(size, stride):
        # Shared settings for every bfloat16 input.
        return rand_strided(size, stride, device='cuda:0', dtype=torch.bfloat16)

    # arg0_1 .. arg5_1: the inputs that come before the repeating group.
    inputs = [
        bf16((32768, 128), (128, 1)),
        rand_strided((1, ), (1, ), device='cuda:0', dtype=torch.int64),
        bf16((32768, 128), (128, 1)),
        bf16((2, 28, 1, 2, 8192, 128), (58720256, 2097152, 2097152, 1048576, 128, 1)),
        bf16((2048, ), (1, )),
        bf16((1, 2048), (2048, 1)),
    ]

    # arg6_1 .. arg257_1: the same nine (size, stride) layouts, in this exact
    # order, repeated 28 times (6 + 28 * 9 == 258 inputs in total).
    repeated_layouts = (
        ((2048, 2048), (2048, 1)),
        ((256, 2048), (2048, 1)),
        ((256, 2048), (2048, 1)),
        ((2048, 2048), (2048, 1)),
        ((2048, ), (1, )),
        ((6144, 2048), (2048, 1)),
        ((6144, 2048), (2048, 1)),
        ((2048, 6144), (6144, 1)),
        ((2048, ), (1, )),
    )
    for _ in range(28):
        # Tensors are created in argument order, matching the unrolled form.
        inputs.extend(bf16(size, stride) for size, stride in repeated_layouts)

    return inputs


def benchmark_compiled_module(args, times=10, repeat=10):
    """Benchmark the module-level ``call`` on ``args`` via ``print_performance``.

    Args:
        args: Inputs for ``call``; copied into a new list on every invocation.
        times: Forwarded to ``print_performance`` as ``times``.
        repeat: Forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns.
    """
    from torch._inductor.utils import print_performance

    def run_once():
        # Hand `call` a fresh list each time so the caller's `args` object is
        # never the list that gets passed in.
        return call(list(args))

    return print_performance(run_once, times=times, repeat=repeat)


# Script entry point: build the example inputs once and hand a benchmark
# callback to compiled_module_main.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    # The first argument is the literal string 'None' (presumably the benchmark
    # name slot -- confirm against compiled_module_main). The callback forwards
    # the requested `times` / `repeat` to benchmark_compiled_module.
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))
