# AOT ID: ['4_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Short module-level names for torch op namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Aliases for helpers exposed on torch._C._dynamo.guards: the size/stride and
# alignment assertions, the `_empty_strided_*` helpers (cpu, cpu_pinned, cuda,
# xpu, mtia) and `_reinterpret_tensor`.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Single AsyncCompile instance; every `async_compile.triton(name, source, ...)`
# call in this module goes through it.
async_compile = AsyncCompile()
# NOTE(review): this attribute lookup runs at import time and presumably requires
# a build that exposes torch._C._distributed_c10d -- confirm for the target env.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
from torch._C import _cuda_getCurrentRawStream as get_raw_stream



# kernel path: /tmp/torchinductor_shingokuga/ly/cly7md7thmpshkq3dh4ixp35kisdgzvf2kec5norc7twrfgulusd.py
# Topologically Sorted Source Nodes: [to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add => add_13
#   hidden => convert_element_type_1
#   hidden_states => mul_6
#   mul => mul_5
#   pow_1 => pow_1
#   rsqrt => rsqrt
#   to => convert_element_type
#   variance => mean
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %buf0 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf0]
#   %arg2_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg2_1]
#   %convert_element_type : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg3_1, torch.float32), kwargs = {})
#   %pow_1 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type, 2), kwargs = {})
#   %mean : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
#   %add_13 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
#   %rsqrt : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_13,), kwargs = {})
#   %mul_5 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg3_1, %rsqrt), kwargs = {})
#   %convert_element_type_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_5, torch.bfloat16), kwargs = {})
#   %mul_6 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_1, %arg2_1), kwargs = {})
#   return %buf0,%mul_6
# Summary of the Triton source registered below (the source is a string literal;
# do not hand-edit it without regenerating):
#   pass 1: load in_ptr0 (bf16, upcast to fp32) in R0_BLOCK-sized chunks and
#           accumulate the sum of squares over the 2048-element reduction axis;
#   pass 2: reload in_ptr0, scale each element by rsqrt(sum / 2048.0 + 1e-05),
#           multiply by the matching in_ptr1 element, and store to out_ptr1.
# Inside the kernel, xnumel and r0_numel are overwritten with the constants 1 and
# 2048, so indexing does not depend on the values passed at launch.
# NOTE(review): the graph fragment's bf16 round-trip (%convert_element_type_1)
# shows up in the kernel as `.to(tl.float32)`; presumably intentional fp32
# upcasting by the code generator -- confirm before relying on exact rounding.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 16384}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0(in_ptr0, in_ptr1, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp2 = tmp1 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(r0_mask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp6 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tl.full([1, 1], 2048.0, tl.float32)
        tmp9 = (tmp4 / tmp8)
        tmp10 = tl.full([1, 1], 1e-05, tl.float32)
        tmp11 = tmp9 + tmp10
        tmp12 = libdevice.rsqrt(tmp11)
        tmp13 = tmp7 * tmp12
        tmp14 = tmp13.to(tl.float32)
        tmp16 = tmp14 * tmp15
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp16, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ap/capwdp2prh5dl2t7iu27qvtidqr22bfvpzpimuryjt42laznawaq.py
# Topologically Sorted Source Nodes: [key_cache, view_1, key_states_1, setitem], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_cache => select, select_1
#   key_states_1 => permute_4
#   setitem => index_put, view_3
#   view_1 => view_1
# Graph fragment:
#   %copy_ : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=copy_]
#   %select : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg1_1, 0, 0), kwargs = {})
#   %select_1 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select, 0, 0), kwargs = {})
#   %view_1 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [1, 1, 2, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_1, [0, 2, 1, 3]), kwargs = {})
#   %view_3 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_4, [2, 1, 128]), kwargs = {})
#   %index_put : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_1, [None, None, %arg7_1], %view_3), kwargs = {})
#   return %index_put
# Summary of the Triton source registered below: a flat element-wise copy of
# 2097152 bf16 elements from in_ptr0 to out_ptr0 (xnumel is overwritten with that
# constant inside the kernel).
# The load and store pass `None` as the mask; `xmask` is built but never used.
# NOTE(review): with no mask, correctness presumably relies on the launch grid
# covering exactly 2097152 elements -- confirm at the call site.
triton_poi_fused_index_put_select_transpose_view_1 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/of/cof54iwtkwd3ajnnwg3sfmxfqlp3s2cavlhrh3vksz3qhfqkwxxc.py
# Topologically Sorted Source Nodes: [key_cache, view_1, key_states_1, setitem], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_cache => select, select_1
#   key_states_1 => permute_4
#   setitem => index_put, view_3
#   view_1 => view_1
# Graph fragment:
#   %arg7_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg7_1]
#   %mm_1 : Tensor "bf16[1, 256][256, 1]cuda:0" = PlaceHolder[target=mm_1]
#   %index_put : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=index_put]
#   %select : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg1_1, 0, 0), kwargs = {})
#   %select_1 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select, 0, 0), kwargs = {})
#   %view_1 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [1, 1, 2, 128]), kwargs = {})
#   %permute_4 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_1, [0, 2, 1, 3]), kwargs = {})
#   %view_3 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_4, [2, 1, 128]), kwargs = {})
#   %index_put : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_1, [None, None, %arg7_1], %view_3), kwargs = {})
#   return %buf5
# Summary of the Triton source registered below: an indexed scatter of 256 values.
#   - in_ptr0[0] supplies one i64 index; a negative value is wrapped by adding
#     8192, and a device assert checks 0 <= index < 8192.
#   - each of the 256 elements of in_ptr1 (x0 = xindex % 128, x1 = xindex // 128)
#     is stored to out_ptr0 at offset x0 + 128*index + 1048576*x1.
# inductor_meta lists out_ptr0 under 'mutated_arg_names', i.e. the destination is
# updated in place. The store and the in_ptr1 load are masked with xindex < 256.
triton_poi_fused_index_put_select_transpose_view_2 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 256}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_2', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 512}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_2(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = (xindex % 128)
    x1 = xindex // 128
    tmp0 = tl.load(in_ptr0 + (0))
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK])
    tmp7 = tl.load(in_ptr1 + (x2), xmask).to(tl.float32)
    tmp2 = tl.full([XBLOCK], 8192, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 8192), "index out of bounds: 0 <= tmp5 < 8192")
    tl.store(out_ptr0 + (x0 + 128*tmp5 + 1048576*x1), tmp7, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/jn/cjnqgw4td3nxaix2cgenji3jyehezihzhtcoxom54vorgvstkshf.py
# Topologically Sorted Source Nodes: [setitem_1, view_2, value_states_1], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_1 => index_put_1, select_7, select_8, view_4
#   value_states_1 => permute_5
#   view_2 => view_2
# Graph fragment:
#   %buf5 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf5]
#   %copy_ : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=copy_]
#   %select_int : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg1_1, 0, 0), kwargs = {})
#   %select_scatter_default : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int, %index_put, 0, 0), kwargs = {})
#   %select_scatter_default_1 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%arg1_1, %select_scatter_default, 0, 0), kwargs = {})
#   %select_7 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_1, 0, 1), kwargs = {})
#   %select_8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_7, 0, 0), kwargs = {})
#   %view_2 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [1, 1, 2, 128]), kwargs = {})
#   %permute_5 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_2, [0, 2, 1, 3]), kwargs = {})
#   %view_4 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_5, [2, 1, 128]), kwargs = {})
#   %index_put_1 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_8, [None, None, %arg7_1], %view_4), kwargs = {})
#   return %index_put_1
# Summary of the Triton source registered below: writes 2097152 elements to
# out_ptr0 (xnumel is overwritten with that constant inside the kernel).
# The two nested tl.where selects are driven by comparisons of constants:
#   tmp2 = (1 == 0) is always false, tmp3 = (0 == 0) is always true,
# so the stored value is always tmp7 = in_ptr1[x0 + 2097152*ks0].
# NOTE(review): the loads from in_ptr0[x0] (tmp4) and in_ptr1[x0] (tmp5) feed only
# the branch that the constant-false tmp2 discards; they look like leftovers of the
# select_scatter lowering shown in the graph fragment -- confirm before removing.
# Loads and the store are unmasked (mask is None); `xmask` is built but unused.
triton_poi_fused_index_put_select_transpose_view_3 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_3', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_3(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp4 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
    tmp7 = tl.load(in_ptr1 + (x0 + 2097152*ks0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tmp1 == tmp1
    tmp6 = tl.where(tmp3, tmp4, tmp5)
    tmp8 = tl.where(tmp2, tmp6, tmp7)
    tl.store(out_ptr0 + (x0), tmp8, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/pe/cpejyk6ibhzldjmxk6aru46ly5h3vucbokcizk34oqkvwnz2rg4f.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf8]
#   %buf5 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf5]
#   %copy_ : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=copy_]
#   %select_int : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%arg1_1, 0, 0), kwargs = {})
#   %select_scatter_default : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int, %index_put, 0, 0), kwargs = {})
#   %select_scatter_default_1 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%arg1_1, %select_scatter_default, 0, 0), kwargs = {})
#   %select_int_1 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_1, 0, 1), kwargs = {})
#   %select_scatter_default_2 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_1, %index_put_1, 0, 0), kwargs = {})
#   %select_scatter_default_3 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_1, %select_scatter_default_2, 0, 1), kwargs = {})
#   return %select_scatter_default_3
# Summary of the Triton source registered below: a flat select/merge that writes
# out_ptr0[x4] for x4 = xindex, with the index decomposed as
#   x2 = xindex // ks0, x1 = (xindex // 2097152) % ks1,
#   x0 = xindex % 2097152, x3 = xindex % ks0.
# Because tmp7 = (1 == 0) is constant-false, the stored value reduces to:
#   x2 == 1:  in_ptr0[x0]        if x1 == 0 else in_ptr2[ks0 + x3]
#   x2 == 0:  in_ptr1[x0]        if x1 == 0 else in_ptr2[x3]
#   otherwise: in_ptr2[x4]
# Unlike the sibling kernels, xnumel is not overwritten with a constant here, yet
# it is never compared against xindex: all loads and the store use mask=None.
# NOTE(review): correctness therefore presumably depends on the launch grid
# covering exactly xnumel elements -- confirm at the call site.
triton_poi_fused_4 = async_compile.triton('triton_poi_fused_4', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 268435456}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_4(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // ks0
    x1 = ((xindex // 2097152) % ks1)
    x0 = (xindex % 2097152)
    x3 = (xindex % ks0)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp11 = tl.load(in_ptr2 + (ks0 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp15 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 0, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tmp1 == tmp4
    tmp10 = tl.where(tmp5, tmp8, tmp9)
    tmp12 = tl.where(tmp7, tmp10, tmp11)
    tmp13 = tl.where(tmp5, tmp6, tmp12)
    tmp14 = tmp0 == tmp4
    tmp16 = tl.where(tmp14, tmp10, tmp15)
    tmp17 = tl.where(tmp2, tmp13, tmp16)
    tl.store(out_ptr0 + (x4), tmp17, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/as/cas7mjqvackhonppwch545ac6noa6uuycf4y3nxpubhmhf25yv4o.py
# Topologically Sorted Source Nodes: [view, query_states_1, attn_output], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
# Source node to ATen node mapping:
#   attn_output => convert_element_type_8, mul_8
#   query_states_1 => permute_3
#   view => view
# Graph fragment:
#   %mm : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm]
#   %view : Tensor "bf16[1, 1, 16, 128][2048, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [1, 1, 16, 128]), kwargs = {})
#   %permute_3 : Tensor "bf16[1, 16, 1, 128][2048, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_8 : Tensor "f32[1, 16, 1, 128][2048, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_3, torch.float32), kwargs = {})
#   %mul_8 : Tensor "f32[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%convert_element_type_8, 0.29730177875068026), kwargs = {})
#   return %expand_2
# Summary of the Triton source registered below: reads 2048 bf16 values from
# in_ptr0, upcasts them to fp32, multiplies by the constant 0.29730177875068026,
# and stores the fp32 result to out_ptr0 (signature: '*bf16' in, '*fp32' out).
# xnumel is overwritten with 2048 inside the kernel; accesses are masked with
# xindex < 2048.
# NOTE(review): 0.29730177875068026 equals 128 ** -0.25; presumably the attention
# scale split evenly between the query and key operands -- confirm against the
# originating attention call.
triton_poi_fused__to_copy_mul_transpose_view_5 = async_compile.triton('triton_poi_fused__to_copy_mul_transpose_view_5', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2048}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_mul_transpose_view_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 20480}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_mul_transpose_view_5(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/i3/ci3sgjr3yrojpp63uf5cr4g2u524be7spgcfx64kalj7f5res6ks.py
# Topologically Sorted Source Nodes: [attn_output], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output => clone, convert_element_type_9, expand, mul_9, permute_6, select_12, select_13, unsqueeze, view_5
# Graph fragment:
#   %select_scatter_default_3 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_12 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 0), kwargs = {})
#   %select_13 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_12, 0, 0), kwargs = {})
#   %convert_element_type_9 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_13, torch.float32), kwargs = {})
#   %unsqueeze : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_9, 2), kwargs = {})
#   %expand : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand,), kwargs = {memory_format: torch.contiguous_format})
#   %view_5 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone, [1, 16, 8192, 128]), kwargs = {})
#   %permute_6 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_5, [0, 1, 3, 2]), kwargs = {})
#   %mul_9 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_6, 0.29730177875068026), kwargs = {})
#   return %expand_3
# Summary of the kernel registered below (2D pointwise, y extent 2048, x extent 8192,
# both hard-coded inside the kernel body):
#   * The y index is split into y0 = y % 128 and y1 = y // 128.
#   * Each element is loaded as bf16 from in_ptr0 at y0 + 128*x + 1048576*(y1 // 8), so
#     eight consecutive y1 values read the same 1048576-element input slab (this matches
#     the expand-to-8 step listed in the graph fragment above).
#   * The value is widened to fp32, multiplied by 0.29730177875068026 and written to
#     out_ptr0 at x + 8192*y.
#   * Loads and stores are issued with mask None; xmask/ymask are built but are all-True.
# NOTE(review): the scale is numerically equal to 128 ** -0.25; presumably this is an
# attention scale applied to one matmul operand -- confirm against the model source.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_6 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_6', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_6(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ul/culbhlbjostmdyicvpocaf3ngtvccpzmqrkvxtyunpepfojlah2r.py
# Topologically Sorted Source Nodes: [attn_output, arange, attn_mask], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
# Source node to ATen node mapping:
#   arange => iota
#   attn_mask => le
#   attn_output => add_15, any_1, div, eq_2, full_default, full_default_1, full_default_2, logical_not, logical_not_1, view_9, where, where_1
# Graph fragment:
#   %bmm : Tensor "f32[16, 1, 8192][8192, 8192, 1]cuda:0" = PlaceHolder[target=bmm]
#   %arg7_1 : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=arg7_1]
#   %any_1 : Tensor "b8[1, 16, 1, 1][16, 1, 16, 16]cuda:0" = PlaceHolder[target=any_1]
#   %getitem_14 : Tensor "f32[1, 16, 1, 1][16, 1, 16, 16]cuda:0" = PlaceHolder[target=getitem_14]
#   %getitem_15 : Tensor "f32[1, 16, 1, 1][16, 1, 16, 16]cuda:0" = PlaceHolder[target=getitem_15]
#   %view_9 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm, [1, 16, 1, 8192]), kwargs = {})
#   %iota : Tensor "i64[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (8192,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %le : Tensor "b8[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.le.Tensor](args = (%iota, %arg7_1), kwargs = {})
#   %full_default_1 : Tensor "bf16[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %full_default : Tensor "bf16[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], -inf), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %where : Tensor "bf16[8192][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le, %full_default_1, %full_default), kwargs = {})
#   %add_15 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_9, %where), kwargs = {})
#   %eq_2 : Tensor "b8[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.eq.Scalar](args = (%add_15, -inf), kwargs = {})
#   %logical_not : Tensor "b8[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.logical_not.default](args = (%eq_2,), kwargs = {})
#   %any_1 : Tensor "b8[1, 16, 1, 1][16, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.any.dim](args = (%logical_not, -1, True), kwargs = {})
#   %logical_not_1 : Tensor "b8[1, 16, 1, 1][16, 1, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.logical_not.default](args = (%any_1,), kwargs = {})
#   %full_default_2 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([1, 16, 1, 8192], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %prepare_softmax_online_default_7 : [num_users=2] = call_function[target=torch.ops.prims.prepare_softmax_online.default](args = (%add_15, -1), kwargs = {})
#   %sub_tensor_7 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%add_15, %getitem_14), kwargs = {})
#   %exp_default_7 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_tensor_7,), kwargs = {})
#   %div : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp_default_7, %getitem_15), kwargs = {})
#   %where_1 : Tensor "f32[1, 16, 1, 8192][131072, 8192, 8192, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%logical_not_1, %full_default_2, %div), kwargs = {})
#   return %any_1,%getitem_14,%getitem_15,%expand_4
# Summary of the kernel registered below (looped reduction, 16 rows x 8192 columns,
# both hard-coded inside the kernel body; in_out_ptr0 is read and overwritten in place):
#   * in_ptr0[0] is a single i64 value; a column r receives an additive mask of 0.0 when
#     r <= in_ptr0[0] and -inf otherwise.
#   * Pass 1 walks each row in R0_BLOCK chunks and accumulates (a) whether any masked
#     element differs from -inf and (b) the running state fed to
#     triton_helpers.online_softmax_combine, which is then collapsed by
#     triton_helpers.online_softmax_reduce into tmp19 / tmp20.
#   * Pass 2 reloads the row, re-applies the mask, computes exp(value - tmp19) / tmp20 and
#     stores it back to in_out_ptr0. Rows for which pass 1 found no element other than
#     -inf are written as 0.0 instead.
#   * Loads and stores in both passes are masked with r0_mask & xmask.
triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7 = async_compile.triton('triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 16, 'r0_': 8192},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 3, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 1572864}}
)
@triton.jit
def triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7(in_out_ptr0, in_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 16
    r0_numel = 8192
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    tmp1 = tl.load(in_ptr0 + (0))
    tmp2 = tl.broadcast_to(tmp1, [1, 1])
    _tmp15 = tl.full([XBLOCK, R0_BLOCK], False, tl.int1)
    _tmp19_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
    _tmp19_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_out_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
        tmp3 = r0_1
        tmp4 = tmp3 <= tmp2
        tmp5 = tl.full([1, 1], 0.0, tl.float32)
        tmp6 = tl.full([1, 1], float("-inf"), tl.float32)
        tmp7 = tl.where(tmp4, tmp5, tmp6)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp0 + tmp8
        tmp10 = tmp9 == tmp6
        tmp11 = tmp10 == 0
        tmp12 = tmp11.to(tl.int64)
        tmp13 = (tmp12 != 0)
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
        tmp16 = _tmp15 | tmp14
        _tmp15 = tl.where(r0_mask & xmask, tmp16, _tmp15)
        tmp18 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])

        _tmp19_max_next, _tmp19_sum_next = triton_helpers.online_softmax_combine(
            _tmp19_max, _tmp19_sum, tmp18, False
        )

        _tmp19_max = tl.where(r0_mask & xmask, _tmp19_max_next, _tmp19_max)
        _tmp19_sum = tl.where(r0_mask & xmask, _tmp19_sum_next, _tmp19_sum)
    tmp17 = _tmp15.to(tl.int8)
    tmp15 = triton_helpers.any(tmp17, 1)[:, None]

    tmp19, tmp20 = triton_helpers.online_softmax_reduce(
        _tmp19_max, _tmp19_sum, 1, False)
    tmp19 = tmp19[:, None]
    tmp20 = tmp20[:, None]
    tmp23 = tl.load(in_ptr0 + (0))
    tmp24 = tl.broadcast_to(tmp23, [1, 1])
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp22 = tl.load(in_out_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp21 = tmp15 == 0
        tmp25 = r0_1
        tmp26 = tmp25 <= tmp24
        tmp27 = tl.full([1, 1], 0.0, tl.float32)
        tmp28 = tl.full([1, 1], float("-inf"), tl.float32)
        tmp29 = tl.where(tmp26, tmp27, tmp28)
        tmp30 = tmp29.to(tl.float32)
        tmp31 = tmp22 + tmp30
        tmp32 = tmp31 - tmp19
        tmp33 = libdevice.exp(tmp32)
        tmp34 = (tmp33 / tmp20)
        tmp35 = tl.where(tmp21, tmp27, tmp34)
        tl.store(in_out_ptr0 + (r0_1 + 8192*x0), tmp35, r0_mask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/xj/cxjmnkzqy4bg5myypjlbwh5a723q44yt2kwmzaumupvufuqtyvhg.py
# Topologically Sorted Source Nodes: [setitem_1, attn_output], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output => clone_1, convert_element_type_10, expand_1, unsqueeze_1
#   setitem_1 => select_10, select_11
# Graph fragment:
#   %select_scatter_default_3 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_10 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 1), kwargs = {})
#   %select_11 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_10, 0, 0), kwargs = {})
#   %convert_element_type_10 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_11, torch.float32), kwargs = {})
#   %unsqueeze_1 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_10, 2), kwargs = {})
#   %expand_1 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_1, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_1 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_1,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_1
# Summary of the kernel registered below (1D pointwise over 16777216 elements, hard-coded
# inside the kernel body):
#   * The flat index x is split into x0 = x % 1048576 and x2 = x // 8388608.
#   * Each element is loaded as bf16 from in_ptr0 at ks0 + x0 + 1048576*x2, widened to
#     fp32 and stored contiguously at out_ptr0[x]. Because x2 only changes every 8388608
#     elements, each 1048576-element input slab is written out eight times.
#   * ks0 is an i64 element offset added to every load index; its value is supplied by
#     the caller, which is not part of this kernel definition.
#   * Loads and stores are issued with mask None; xmask is built but is all-True.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_8 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_8', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_8(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (ks0 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/jw/cjw32t4wuof3anjr55vs5cwk5s4ybxgmccpf4pxc4ueqwaqhfb3t.py
# Topologically Sorted Source Nodes: [attn_output], Original ATen: [aten.view, aten._to_copy]
# Source node to ATen node mapping:
#   attn_output => convert_element_type_11, view_12
# Graph fragment:
#   %bmm_1 : Tensor "f32[16, 1, 128][128, 128, 1]cuda:0" = PlaceHolder[target=bmm_1]
#   %view_12 : Tensor "f32[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%bmm_1, [1, 16, 1, 128]), kwargs = {})
#   %convert_element_type_11 : Tensor "bf16[1, 16, 1, 128][2048, 128, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_12, torch.bfloat16), kwargs = {})
#   return %convert_element_type_11
# Summary of the kernel registered below (1D pointwise over 2048 elements, hard-coded
# inside the kernel body):
#   * Copies in_ptr0[x] to out_ptr0[x] under the bounds mask x < 2048.
#   * The signature declares in_ptr0 as '*fp32' and out_ptr0 as '*bf16'. The only explicit
#     cast in the body is to tl.float32, so the narrowing to bf16 is left to tl.store,
#     which casts the stored value to the pointer's element type.
triton_poi_fused__to_copy_view_9 = async_compile.triton('triton_poi_fused__to_copy_view_9', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2048}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_view_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16384}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_view_9(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/z6/cz67zft6jfhbnspm25arijg6r4yvv4righ4dba57hk44pkxqun36.py
# Topologically Sorted Source Nodes: [hidden_states_1, to_2, pow_2, variance_1, add_2, rsqrt_1, mul_2, hidden_1, hidden_states_2], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_2 => add_17
#   hidden_1 => convert_element_type_16
#   hidden_states_1 => add_16
#   hidden_states_2 => mul_11
#   mul_2 => mul_10
#   pow_2 => pow_2
#   rsqrt_1 => rsqrt_1
#   to_2 => convert_element_type_15
#   variance_1 => mean_1
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %mm_3 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %buf21 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf21]
#   %arg9_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg9_1]
#   %add_16 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %mm_3), kwargs = {})
#   %convert_element_type_15 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_16, torch.float32), kwargs = {})
#   %pow_2 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_15, 2), kwargs = {})
#   %mean_1 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
#   %add_17 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
#   %rsqrt_1 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_17,), kwargs = {})
#   %mul_10 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_16, %rsqrt_1), kwargs = {})
#   %convert_element_type_16 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_10, torch.bfloat16), kwargs = {})
#   %mul_11 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_16, %arg9_1), kwargs = {})
#   return %buf21,%mul_11
# Summary of the kernel registered below (looped reduction over a single row of 2048
# elements; xnumel = 1 and r0_numel = 2048 are hard-coded inside the kernel body):
#   * Pass 1 accumulates sum((in_ptr0[r] + in_ptr1[r]) ** 2) in fp32 over the row.
#   * Pass 2 reloads both inputs, recomputes their sum, multiplies it by
#     rsqrt(sum_of_squares / 2048 + 1e-05) and then by in_ptr2[r], and stores the result
#     to out_ptr1[r].
#   * All three inputs are loaded as bf16 and widened to fp32; out_ptr1 is declared
#     '*bf16' in the signature.
#   * Loads and stores are masked with r0_mask only; xmask is built but is all-True.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 20480}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp6 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp3 * tmp3
        tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
        tmp7 = _tmp6 + tmp5
        _tmp6 = tl.where(r0_mask, tmp7, _tmp6)
    tmp6 = tl.sum(_tmp6, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp8 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp9 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp19 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp10 = tmp8 + tmp9
        tmp11 = tmp10.to(tl.float32)
        tmp12 = tl.full([1, 1], 2048.0, tl.float32)
        tmp13 = (tmp6 / tmp12)
        tmp14 = tl.full([1, 1], 1e-05, tl.float32)
        tmp15 = tmp13 + tmp14
        tmp16 = libdevice.rsqrt(tmp15)
        tmp17 = tmp11 * tmp16
        tmp18 = tmp17.to(tl.float32)
        tmp20 = tmp18 * tmp19
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp20, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/6l/c6ljcjnehvln6c6qpx3vfiuach6czhl7t4yotfvuhcuu2mri32ee.py
# Topologically Sorted Source Nodes: [silu, mul_4], Original ATen: [aten.silu, aten.mul]
# Source node to ATen node mapping:
#   mul_4 => mul_12
#   silu => add_18, convert_element_type_19, convert_element_type_20, div_1, exp_1, neg
# Graph fragment:
#   %mm_4 : Tensor "bf16[1, 6144][6144, 1]cuda:0" = PlaceHolder[target=mm_4]
#   %mm_5 : Tensor "bf16[1, 6144][6144, 1]cuda:0" = PlaceHolder[target=mm_5]
#   %convert_element_type_19 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_4, torch.float32), kwargs = {})
#   %neg : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%convert_element_type_19,), kwargs = {})
#   %exp_1 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%neg,), kwargs = {})
#   %add_18 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%exp_1, 1), kwargs = {})
#   %div_1 : Tensor "f32[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_19, %add_18), kwargs = {})
#   %convert_element_type_20 : Tensor "bf16[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%div_1, torch.bfloat16), kwargs = {})
#   %mul_12 : Tensor "bf16[1, 6144][6144, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_20, %mm_5), kwargs = {})
#   return %mul_12
# Summary of the kernel registered below (1D pointwise over 6144 elements, hard-coded
# inside the kernel body; in_out_ptr0 is read and overwritten in place):
#   * Loads a = in_out_ptr0[x] and b = in_ptr0[x] as bf16 and widens both to fp32.
#   * Computes a / (exp(-a) + 1.0) -- the SiLU form listed in the graph fragment above --
#     multiplies it by b and stores the product back to in_out_ptr0[x].
#   * Loads and stores are masked with x < 6144.
triton_poi_fused_mul_silu_11 = async_compile.triton('triton_poi_fused_mul_silu_11', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8192}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_11', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 49152}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_11(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6144
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp8 = tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = -tmp1
    tmp3 = libdevice.exp(tmp2)
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = (tmp1 / tmp5)
    tmp7 = tmp6.to(tl.float32)
    tmp9 = tmp7 * tmp8
    tl.store(in_out_ptr0 + (x0), tmp9, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qc/cqc26n32rprli5aetengorxkzfj2vhrjiqovg42epstwcloztyhm.py
# Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, to_4, pow_3, variance_2, add_4, rsqrt_2, mul_5, hidden_2, hidden_states_5], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_4 => add_33
#   hidden_2 => convert_element_type_26
#   hidden_states_1 => add_16
#   hidden_states_4 => add_19
#   hidden_states_5 => mul_19
#   mul_5 => mul_18
#   pow_3 => pow_3
#   rsqrt_2 => rsqrt_2
#   to_4 => convert_element_type_25
#   variance_2 => mean_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %mm_3 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %buf27 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf27]
#   %arg13_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg13_1]
#   %add_16 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %mm_3), kwargs = {})
#   %add_19 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_16, %mm_6), kwargs = {})
#   %convert_element_type_25 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_19, torch.float32), kwargs = {})
#   %pow_3 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_25, 2), kwargs = {})
#   %mean_2 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
#   %add_33 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {})
#   %rsqrt_2 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_33,), kwargs = {})
#   %mul_18 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_19, %rsqrt_2), kwargs = {})
#   %convert_element_type_26 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_18, torch.bfloat16), kwargs = {})
#   %mul_19 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_26, %arg13_1), kwargs = {})
#   return %buf27,%mul_19
#
# Reviewer notes on the kernel below (the kernel source string is unchanged):
#   Two-pass reduction over one row; the body pins xnumel = 1 and
#   r0_numel = 2048 regardless of the values passed in.
#   pass 1: s = in_ptr0 + in_ptr1 + in_ptr2 (each load upcast to float32);
#           s*s is accumulated per lane in _tmp8 and reduced with tl.sum
#           into tmp8. Loads use eviction_policy='evict_last'.
#   pass 2: the same three inputs are re-read with 'evict_first', s is
#           recomputed, and s * rsqrt(tmp8 / 2048 + 1e-05) * in_ptr3 is
#           stored to out_ptr1.
#   Every cast in the body is to tl.float32, so the intermediate sums are not
#   rounded to bf16 inside the kernel, although the graph fragment lists the
#   adds as bf16. out_ptr1 is declared '*bf16' in the signature, so the
#   float32 product is presumably narrowed by tl.store -- confirm.
#   Presumably in_ptr0..in_ptr3 map to arg3_1, mm_3, mm_6, arg13_1 in that
#   order; the call site is not visible here -- verify.
#   xindex, xmask, rnumel, RBLOCK, rbase, roffset and rindex are assigned but
#   never read in the body.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 24576}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp5 * tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
        tmp9 = _tmp8 + tmp7
        _tmp8 = tl.where(r0_mask, tmp9, _tmp8)
    tmp8 = tl.sum(_tmp8, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp10 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tmp10 + tmp11
        tmp14 = tmp12 + tmp13
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tl.full([1, 1], 2048.0, tl.float32)
        tmp17 = (tmp8 / tmp16)
        tmp18 = tl.full([1, 1], 1e-05, tl.float32)
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp24, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ot/cotvxmoqekn2btajce62wr5uhjs736fpz7kp75rv5vmhekihdcvw.py
# Topologically Sorted Source Nodes: [setitem_2, view_4, key_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_states_3 => permute_16
#   setitem_2 => index_put_2, select_20, select_21, view_17
#   view_4 => view_15
# Graph fragment:
#   %select_scatter_default_3 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_20 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 0), kwargs = {})
#   %select_21 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_20, 0, 1), kwargs = {})
#   %view_15 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_8, [1, 1, 2, 128]), kwargs = {})
#   %permute_16 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_15, [0, 2, 1, 3]), kwargs = {})
#   %view_17 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_16, [2, 1, 128]), kwargs = {})
#   %index_put_2 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_21, [None, None, %arg7_1], %view_17), kwargs = {})
#   return %index_put_2
#
# Reviewer notes on the kernel below (the kernel source string is unchanged):
#   Flat element-wise copy with a fixed source offset:
#       out_ptr0[i] = in_ptr0[2097152 + i]
#   The load is upcast to float32 before the store; out_ptr0 is declared
#   '*bf16' in the signature.
#   The body only reads and copies; no indexed write (the index_put in the
#   kernel name) happens inside this kernel. Presumably the scatter of the
#   new values is applied to the output buffer by a later step -- confirm at
#   the call site.
#   Load and store both pass mask=None and xmask is never read, so the launch
#   grid must cover exactly the 2097152 elements set as xnumel in the body --
#   presumably guaranteed by the launcher; verify.
triton_poi_fused_index_put_select_transpose_view_13 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_13', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_13(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (2097152 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/l7/cl7rzlaqffo4idlg2sgbk4orgzjyaahthju6cn4h4jdnnszwqboz.py
# Topologically Sorted Source Nodes: [setitem_3, view_5, value_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_3 => index_put_3, select_25, select_26, view_18
#   value_states_3 => permute_17
#   view_5 => view_16
# Graph fragment:
#   %buf32 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf32]
#   %select_scatter_default_3 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_int_2 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 0), kwargs = {})
#   %select_scatter_default_4 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_2, %index_put_2, 0, 1), kwargs = {})
#   %select_scatter_default_5 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_3, %select_scatter_default_4, 0, 0), kwargs = {})
#   %select_25 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_5, 0, 1), kwargs = {})
#   %select_26 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_25, 0, 1), kwargs = {})
#   %view_16 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_9, [1, 1, 2, 128]), kwargs = {})
#   %permute_17 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_16, [0, 2, 1, 3]), kwargs = {})
#   %view_18 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_17, [2, 1, 128]), kwargs = {})
#   %index_put_3 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_26, [None, None, %arg7_1], %view_18), kwargs = {})
#   return %index_put_3
#
# Reviewer notes on the kernel below (the kernel source string is unchanged):
#   Three values are loaded per element:
#       tmp4 = in_ptr0[i]
#       tmp5 = in_ptr1[2097152 + i]
#       tmp7 = in_ptr1[2097152 + ks0 + i]
#   The two selector conditions are comparisons of the constants 1 and 0:
#   tmp2 = (1 == 0) is always False and tmp3 = (1 == 1) is always True, so
#       tmp6 = where(True,  tmp4, tmp5) = tmp4
#       tmp8 = where(False, tmp6, tmp7) = tmp7
#   The value stored is therefore always in_ptr1[2097152 + ks0 + i]; tmp4 and
#   tmp5 never reach the store.
#   ks0 is a runtime i64 offset; presumably the element stride of the leading
#   dimension of the source tensor (2097152*s27 in the graph fragment) --
#   confirm at the call site.
#   Loads and the store pass mask=None; xmask is never read.
triton_poi_fused_index_put_select_transpose_view_14 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_14', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_14', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_14(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp4 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr1 + (2097152 + x0), None).to(tl.float32)
    tmp7 = tl.load(in_ptr1 + (2097152 + ks0 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tmp0 == tmp0
    tmp6 = tl.where(tmp3, tmp4, tmp5)
    tmp8 = tl.where(tmp2, tmp6, tmp7)
    tl.store(out_ptr0 + (x0), tmp8, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/oy/coyikyd4qhdhfk3mxspik3cmevuklz6ox7bbpcepaz6iwjpxkvbb.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf35 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf35]
#   %buf32 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf32]
#   %select_scatter_default_3 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_int_2 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_3, 0, 0), kwargs = {})
#   %select_scatter_default_4 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_2, %index_put_2, 0, 1), kwargs = {})
#   %select_scatter_default_5 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_3, %select_scatter_default_4, 0, 0), kwargs = {})
#   %select_int_3 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_5, 0, 1), kwargs = {})
#   %select_scatter_default_6 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_3, %index_put_3, 0, 1), kwargs = {})
#   %select_scatter_default_7 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_5, %select_scatter_default_6, 0, 1), kwargs = {})
#   return %select_scatter_default_7
#
# Reviewer notes on the kernel below (the kernel source string is unchanged):
#   Index decomposition of the flat output index i (= x4):
#       x2 = i // ks0
#       x1 = (i // 2097152) % ks1
#       x0 = i % 2097152
#       x3 = i % ks0
#   tmp7 = (1 == 0) is always False, so tmp12 is always tmp11. After folding
#   that constant the stored value is:
#       x2 == 1:   in_ptr0[x0]  if x1 == 1  else  in_ptr2[ks0 + x3]
#       x2 == 0:   in_ptr1[x0]  if x1 == 1  else  in_ptr2[x3]
#       otherwise: in_ptr2[i]
#   Presumably in_ptr0 / in_ptr1 / in_ptr2 are buf35 / buf32 /
#   select_scatter_default_3 from the graph fragment, ks0 is the leading-dim
#   element stride (2097152*s27) and ks1 is s27 -- none of this is visible in
#   the kernel itself; confirm at the call site.
#   Unlike the neighbouring kernels, xnumel is not overwritten with a constant
#   here, and the body never reads it; all loads and the store pass mask=None,
#   so the launch grid is presumably sized to cover the output exactly --
#   verify.
triton_poi_fused_15 = async_compile.triton('triton_poi_fused_15', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 268435456}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_15(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // ks0
    x1 = ((xindex // 2097152) % ks1)
    x0 = (xindex % 2097152)
    x3 = (xindex % ks0)
    x4 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp11 = tl.load(in_ptr2 + (ks0 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp15 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tmp3 == tmp1
    tmp6 = tl.full([1], 0, tl.int32)
    tmp7 = tmp1 == tmp6
    tmp10 = tl.where(tmp4, tmp8, tmp9)
    tmp12 = tl.where(tmp7, tmp10, tmp11)
    tmp13 = tl.where(tmp4, tmp5, tmp12)
    tmp14 = tmp0 == tmp6
    tmp16 = tl.where(tmp14, tmp10, tmp15)
    tmp17 = tl.where(tmp2, tmp13, tmp16)
    tl.store(out_ptr0 + (x4), tmp17, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/no/cnog46ty6yiushpgmvahwvuoa2wa2ybmembjfqgbmd2rojdxvgnm.py
# Topologically Sorted Source Nodes: [attn_output_4], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_4 => clone_2, convert_element_type_34, expand_6, mul_22, permute_18, select_30, select_31, unsqueeze_2, view_19
# Graph fragment:
#   %select_scatter_default_7 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_30 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 0), kwargs = {})
#   %select_31 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_30, 0, 1), kwargs = {})
#   %convert_element_type_34 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_31, torch.float32), kwargs = {})
#   %unsqueeze_2 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_34, 2), kwargs = {})
#   %expand_6 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_2, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_2 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_6,), kwargs = {memory_format: torch.contiguous_format})
#   %view_19 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_2, [1, 16, 8192, 128]), kwargs = {})
#   %permute_18 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_19, [0, 1, 3, 2]), kwargs = {})
#   %mul_22 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_18, 0.29730177875068026), kwargs = {})
#   return %expand_9
#
# Reviewer notes on the kernel below (the kernel source string is unchanged):
#   2D-tiled kernel; the body pins ynumel = 2048 and xnumel = 8192.
#   With y0 = y % 128 and y1 = y // 128, each output element is
#       out_ptr0[x + 8192*y] =
#           float32(in_ptr0[2097152 + y0 + 128*x + 1048576*(y1 // 8)])
#           * 0.29730177875068026
#   The output is written with x as the fastest-varying index, while the
#   input is read with a stride of 128 along x, i.e. the last two axes are
#   swapped relative to the source layout.
#   (y1 // 8) makes 8 consecutive y1 values read the same 1048576-element
#   source slab.
#   out_ptr0 is declared '*fp32' in the signature; in_ptr0 is '*bf16'.
#   The scale constant is numerically 128 ** -0.25; presumably the attention
#   scale split between the query and key operands -- confirm against the
#   model code.
#   ymask and xmask are assigned but never read; load and store pass
#   mask=None.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_16 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_16', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_16', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_16(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (2097152 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/jt/cjt7apskhg4h5xhdxaebnu6finrumt3pq34xtdux4dl2zg3j5i2p.py
# Topologically Sorted Source Nodes: [setitem_3, attn_output_4], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_4 => clone_3, convert_element_type_35, expand_7, unsqueeze_3
#   setitem_3 => select_28, select_29
# Graph fragment:
#   %select_scatter_default_7 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_28 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 1), kwargs = {})
#   %select_29 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_28, 0, 1), kwargs = {})
#   %convert_element_type_35 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_29, torch.float32), kwargs = {})
#   %unsqueeze_3 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_35, 2), kwargs = {})
#   %expand_7 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_3, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_3 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_7,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_3
#
# Reviewer notes on the kernel below (the kernel source string is unchanged):
#   Flat kernel; the body pins xnumel = 16777216.
#   With x0 = i % 1048576 and x2 = i // 8388608, each output element is
#       out_ptr0[i] = float32(in_ptr0[2097152 + ks0 + x0 + 1048576*x2])
#   The middle index (i // 1048576) % 8 does not appear in the load address,
#   so each 1048576-element source slab is written 8 times in a row to the
#   output (8388608 / 1048576 = 8).
#   out_ptr0 is declared '*fp32' in the signature; in_ptr0 is '*bf16'.
#   ks0 is a runtime i64 offset; presumably the leading-dim element stride of
#   the source tensor (2097152*s27 in the graph fragment) -- confirm at the
#   call site.
#   xmask is assigned but never read; load and store pass mask=None.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_17 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_17', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_17', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_17(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (2097152 + ks0 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/32/c3275lkzlazmbwqwxhbco6fljnnqh7iypnqq3csszfque52lwmpe.py
# Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, hidden_states_6, to_6, pow_4, variance_3, add_6, rsqrt_3, mul_7, hidden_3, hidden_states_7], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_6 => add_37
#   hidden_3 => convert_element_type_41
#   hidden_states_1 => add_16
#   hidden_states_4 => add_19
#   hidden_states_6 => add_36
#   hidden_states_7 => mul_24
#   mul_7 => mul_23
#   pow_4 => pow_4
#   rsqrt_3 => rsqrt_3
#   to_6 => convert_element_type_40
#   variance_3 => mean_3
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %mm_3 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %mm_10 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_10]
#   %buf48 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf48]
#   %arg18_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg18_1]
#   %add_16 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %mm_3), kwargs = {})
#   %add_19 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_16, %mm_6), kwargs = {})
#   %add_36 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_19, %mm_10), kwargs = {})
#   %convert_element_type_40 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_36, torch.float32), kwargs = {})
#   %pow_4 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_40, 2), kwargs = {})
#   %mean_3 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {})
#   %add_37 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {})
#   %rsqrt_3 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_37,), kwargs = {})
#   %mul_23 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_36, %rsqrt_3), kwargs = {})
#   %convert_element_type_41 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_23, torch.bfloat16), kwargs = {})
#   %mul_24 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_41, %arg18_1), kwargs = {})
#   return %buf48,%mul_24
# Registers the Triton source below with async_compile under the name
# 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18' for device 'cuda'.
#
# What the kernel text does (everything is read from the string literal):
#   * xnumel / r0_numel are overwritten inside the kernel with 1 and 2048, so
#     the values passed by the caller for those two arguments are not used.
#   * Pass 1 (first r0 loop): loads in_ptr0..in_ptr3 (bf16 pointers, each load
#     converted to tl.float32), adds the four values, squares the sum and
#     accumulates it into _tmp10; tl.sum then reduces _tmp10 along axis 1.
#   * Pass 2 (second r0 loop): reloads the same four inputs plus in_ptr4,
#     recomputes the four-way sum, multiplies it by
#     rsqrt(sum_of_squares / 2048.0 + 1e-05) and then by the in_ptr4 value,
#     and stores the result through out_ptr1 (declared '*bf16').
#   * Pass 1 loads use eviction_policy='evict_last', pass 2 loads use
#     'evict_first'.
# NOTE(review): per the graph fragment above, in_ptr4 is presumably the
# normalization weight (arg18_1) and in_ptr0..in_ptr3 the residual/matmul
# terms -- the argument binding lives at the call site, confirm there.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 9, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 28672}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp7 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp12 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp17 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp14 = tmp12 + tmp13
        tmp16 = tmp14 + tmp15
        tmp18 = tmp16 + tmp17
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tl.full([1, 1], 2048.0, tl.float32)
        tmp21 = (tmp10 / tmp20)
        tmp22 = tl.full([1, 1], 1e-05, tl.float32)
        tmp23 = tmp21 + tmp22
        tmp24 = libdevice.rsqrt(tmp23)
        tmp25 = tmp19 * tmp24
        tmp26 = tmp25.to(tl.float32)
        tmp28 = tmp26 * tmp27
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp28, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/o7/co77drt5vttusjlhdqp44e2kzb62ppr43qaxukqprnw6q3gu6dn6.py
# Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, hidden_states_6, hidden_states_9, to_8, pow_5, variance_4, add_8, rsqrt_4, mul_10, hidden_4, hidden_states_10], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_8 => add_53
#   hidden_4 => convert_element_type_51
#   hidden_states_1 => add_16
#   hidden_states_10 => mul_32
#   hidden_states_4 => add_19
#   hidden_states_6 => add_36
#   hidden_states_9 => add_39
#   mul_10 => mul_31
#   pow_5 => pow_5
#   rsqrt_4 => rsqrt_4
#   to_8 => convert_element_type_50
#   variance_4 => mean_4
# Graph fragment:
#   %arg3_1 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %mm_3 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %mm_10 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_10]
#   %mm_13 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %add_39 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_39]
#   %buf55 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf55]
#   %arg22_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg22_1]
#   %add_16 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %mm_3), kwargs = {})
#   %add_19 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_16, %mm_6), kwargs = {})
#   %add_36 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_19, %mm_10), kwargs = {})
#   %add_39 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_36, %mm_13), kwargs = {})
#   %convert_element_type_50 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_39, torch.float32), kwargs = {})
#   %pow_5 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_50, 2), kwargs = {})
#   %mean_4 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {})
#   %add_53 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {})
#   %rsqrt_4 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_53,), kwargs = {})
#   %mul_31 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_39, %rsqrt_4), kwargs = {})
#   %convert_element_type_51 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_31, torch.bfloat16), kwargs = {})
#   %mul_32 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_51, %arg22_1), kwargs = {})
#   return %add_39,%buf55,%mul_32
# Registers the Triton source below with async_compile under the name
# 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19' for device 'cuda'.
#
# What the kernel text does (everything is read from the string literal):
#   * xnumel / r0_numel are overwritten inside the kernel with 1 and 2048.
#   * in_out_ptr0 is both read and written (it is the only entry in
#     'mutated_arg_names').
#   * Pass 1 (first r0 loop): loads in_ptr0, in_out_ptr0, in_ptr1, in_ptr2 and
#     in_ptr3 (each converted to tl.float32), adds the five values, accumulates
#     the square of that sum into _tmp12, and stores the five-way sum back
#     through in_out_ptr0. tl.sum then reduces _tmp12 along axis 1.
#   * Pass 2 (second r0 loop): reads back what pass 1 stored in in_out_ptr0,
#     multiplies it by rsqrt(sum_of_squares / 2048.0 + 1e-05) and then by the
#     in_ptr4 value, and stores the result through out_ptr1 (declared '*bf16').
# NOTE(review): per the graph fragment above, the in_out_ptr0 result is
# presumably add_39 and in_ptr4 presumably the weight arg22_1 -- the argument
# binding lives at the call site, confirm there.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 40960}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(r0_mask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp8, r0_mask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp14 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tl.full([1, 1], 2048.0, tl.float32)
        tmp17 = (tmp12 / tmp16)
        tmp18 = tl.full([1, 1], 1e-05, tl.float32)
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp24, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/j5/cj5b33tbq3c6vy23u2oensswrib5mgfozkjspd7jflgxqexzgslf.py
# Topologically Sorted Source Nodes: [setitem_4, view_7, key_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_states_5 => permute_28
#   setitem_4 => index_put_4, select_38, select_39, view_31
#   view_7 => view_29
# Graph fragment:
#   %select_scatter_default_7 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_38 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 0), kwargs = {})
#   %select_39 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_38, 0, 2), kwargs = {})
#   %view_29 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_15, [1, 1, 2, 128]), kwargs = {})
#   %permute_28 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_29, [0, 2, 1, 3]), kwargs = {})
#   %view_31 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_28, [2, 1, 128]), kwargs = {})
#   %index_put_4 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_39, [None, None, %arg7_1], %view_31), kwargs = {})
#   return %index_put_4
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused_index_put_select_transpose_view_20' for device 'cuda'.
#
# What the kernel text does (everything is read from the string literal):
#   * Pointwise copy: out_ptr0[x] = in_ptr0[4194304 + x] for each x produced
#     by program_id(0) * XBLOCK + arange(0, XBLOCK).
#   * xnumel is overwritten with 2097152 inside the kernel; xmask is built but
#     the load and the store both pass None as the mask, i.e. they are
#     unmasked.
#   * The loaded value is converted to tl.float32 before being stored through
#     out_ptr0 (declared '*bf16').
# NOTE(review): 4194304 == 2 * 2097152, which presumably corresponds to the
# select(..., 0, 2) slice named in the graph fragment above -- confirm against
# the call site / buffer layout.
triton_poi_fused_index_put_select_transpose_view_20 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_20', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_20', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_20(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (4194304 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/hv/chvgjswwd7ghkwkcgmsoksuii5xysbluv3lkmofooggmnrwz6oau.py
# Topologically Sorted Source Nodes: [setitem_5, view_8, value_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_5 => index_put_5, select_43, select_44, view_32
#   value_states_5 => permute_29
#   view_8 => view_30
# Graph fragment:
#   %buf60 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf60]
#   %select_scatter_default_7 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_int_4 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 0), kwargs = {})
#   %select_scatter_default_8 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_4, %index_put_4, 0, 2), kwargs = {})
#   %select_scatter_default_9 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_7, %select_scatter_default_8, 0, 0), kwargs = {})
#   %select_43 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_9, 0, 1), kwargs = {})
#   %select_44 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_43, 0, 2), kwargs = {})
#   %view_30 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_16, [1, 1, 2, 128]), kwargs = {})
#   %permute_29 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_30, [0, 2, 1, 3]), kwargs = {})
#   %view_32 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_29, [2, 1, 128]), kwargs = {})
#   %index_put_5 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_44, [None, None, %arg7_1], %view_32), kwargs = {})
#   return %index_put_5
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused_index_put_select_transpose_view_21' for device 'cuda'.
#
# What the kernel text does (everything is read from the string literal):
#   * xnumel is overwritten with 2097152; all loads and the store are unmasked
#     (mask argument is None).
#   * Three loads are issued: in_ptr0[x], in_ptr1[4194304 + x] and
#     in_ptr1[4194304 + ks0 + x], where ks0 is a runtime 'i64' scalar.
#   * tmp2 compares the constants 1 and 0 (always False) and tmp4 compares 2
#     with itself (always True). The final tl.where(tmp2, tmp7, tmp8) therefore
#     always yields tmp8, so the stored value is in_ptr1[4194304 + ks0 + x];
#     the other two loads do not influence the result.
# NOTE(review): the constant predicates look like select_scatter index tests
# folded at codegen time, and ks0 is presumably the element stride between the
# two outermost slices of the cache buffer -- confirm at the call site. The
# unused loads are left as emitted; whether the Triton compiler removes them
# is not visible from this file.
triton_poi_fused_index_put_select_transpose_view_21 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_21', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_21', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_21(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (4194304 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (4194304 + ks0 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 2, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/fy/cfytqwdkhtlb3hvwif6ykszc623xkysxjrogrztpz5h24d2rcfao.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf63 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf63]
#   %buf60 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf60]
#   %select_scatter_default_7 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_7]
#   %select_int_4 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_7, 0, 0), kwargs = {})
#   %select_scatter_default_8 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_4, %index_put_4, 0, 2), kwargs = {})
#   %select_scatter_default_9 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_7, %select_scatter_default_8, 0, 0), kwargs = {})
#   %select_int_5 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_9, 0, 1), kwargs = {})
#   %select_scatter_default_10 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_5, %index_put_5, 0, 2), kwargs = {})
#   %select_scatter_default_11 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_9, %select_scatter_default_10, 0, 1), kwargs = {})
#   return %select_scatter_default_11
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused_22' for device 'cuda'.
#
# What the kernel text does (everything is read from the string literal):
#   * Unlike the neighbouring pointwise kernels, xnumel is NOT overwritten with
#     a constant here, and every load/store is unmasked (mask is None).
#   * Index decomposition of the flat index (ks0, ks1 are runtime 'i64'
#     scalars):
#         x2 = xindex // ks0
#         x1 = (xindex // 2097152) % ks1
#         x0 = xindex % 2097152
#         x3 = xindex % ks0
#   * tmp8 compares the constants 1 and 0, so it is always False and tmp13
#     reduces to tmp12. The stored value is therefore:
#         x2 == 1:  in_ptr0[x0] if x1 == 2 else in_ptr2[ks0 + x3]
#         x2 == 0:  in_ptr1[x0] if x1 == 2 else in_ptr2[x3]
#         else:     in_ptr2[xindex]
#   * All loads use eviction_policy='evict_last'.
# NOTE(review): this presumably materializes select_scatter_default_11 from the
# graph fragment above (in_ptr0/in_ptr1 being the two updated slices and
# in_ptr2 the previous buffer), with ks0 the outer-slice stride and ks1 the
# symbolic size s27 -- confirm at the call site. Because nothing is masked,
# the launch grid presumably covers the element count exactly -- verify.
triton_poi_fused_22 = async_compile.triton('triton_poi_fused_22', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_22', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 268435456}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_22(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // ks0
    x1 = ((xindex // 2097152) % ks1)
    x0 = (xindex % 2097152)
    x3 = (xindex % ks0)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (ks0 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 2, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/q2/cq2j65ygztzyarcrd7mqu5jbcbnb5pwjz52ve7qlrea7acyweper.py
# Topologically Sorted Source Nodes: [attn_output_8], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_8 => clone_4, convert_element_type_59, expand_12, mul_35, permute_30, select_48, select_49, unsqueeze_4, view_33
# Graph fragment:
#   %select_scatter_default_11 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_48 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 0), kwargs = {})
#   %select_49 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_48, 0, 2), kwargs = {})
#   %convert_element_type_59 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_49, torch.float32), kwargs = {})
#   %unsqueeze_4 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_59, 2), kwargs = {})
#   %expand_12 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_4, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_4 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_12,), kwargs = {memory_format: torch.contiguous_format})
#   %view_33 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_4, [1, 16, 8192, 128]), kwargs = {})
#   %permute_30 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_33, [0, 1, 3, 2]), kwargs = {})
#   %mul_35 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_30, 0.29730177875068026), kwargs = {})
#   return %expand_15
# Registers the Triton source below with async_compile under the name
# 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_23'
# for device 'cuda'.
#
# What the kernel text does (everything is read from the string literal):
#   * 2D pointwise kernel ('Grid2D'): y comes from program_id(1), x from
#     program_id(0). ynumel / xnumel are overwritten with 2048 and 8192, and
#     the load and store are unmasked (mask is None).
#   * yindex is split as y0 = yindex % 128 and y1 = yindex // 128.
#   * Load address: in_ptr0[4194304 + y0 + 128 * x + 1048576 * (y1 // 8)].
#     Because of the integer division, eight consecutive y1 values read the
#     same 1048576-element source slab.
#   * The bf16 value is converted to tl.float32, multiplied by the constant
#     0.29730177875068026 and stored to out_ptr0[x + 8192 * yindex]
#     (out_ptr0 is declared '*fp32').
# NOTE(review): the y1 // 8 replication presumably implements the
# expand(..., 8, ...) of key heads shown in the graph fragment above, and the
# constant is approximately 128 ** -0.25 (presumably half of the attention
# scaling exponent for head_dim 128) -- confirm against the model code.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_23 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_23', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_23', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_23(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (4194304 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ki/ckiqfsbcmmoqrspubgsm63hjariniwsjbwbjfy4ur4lde3ftvlof.py
# Topologically Sorted Source Nodes: [setitem_5, attn_output_8], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_8 => clone_5, convert_element_type_60, expand_13, unsqueeze_5
#   setitem_5 => select_46, select_47
# Graph fragment:
#   %select_scatter_default_11 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_46 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 1), kwargs = {})
#   %select_47 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_46, 0, 2), kwargs = {})
#   %convert_element_type_60 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_47, torch.float32), kwargs = {})
#   %unsqueeze_5 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_60, 2), kwargs = {})
#   %expand_13 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_5, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_5 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_13,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_5
# Hands the Triton source below to async_compile.triton under the kernel's name.
#
# What the embedded kernel does (1-D pointwise, xnumel fixed to 16777216):
#   * reads bf16 from in_ptr0 at 4194304 + ks0 + x0 + 1048576*x2, where
#     x0 = xindex % 1048576 and x2 = xindex // 8388608;
#   * casts the value to float32 and writes it to out_ptr0 + xindex.
# The load address does not depend on (xindex // 1048576) % 8, so every
# 1048576-element source slab is written to 8 consecutive output slabs; this
# matches the expand-to-8 followed by a contiguous clone in the graph fragment above.
# ks0 is a runtime i64 offset supplied by the caller.
# NOTE(review): ks0 presumably equals the dim-0 stride of the cache tensor
# (2097152*s27, i.e. select index 1) -- confirm against the call site.
# xmask is computed but unused; loads and stores pass None as the mask.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_24 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_24', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_24', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_24(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (4194304 + ks0 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ju/cjuhspy2vbzqlcflnkfe6ihkjzkpd47naaknluhyfbp5ytziomh5.py
# Topologically Sorted Source Nodes: [setitem_6, view_10, key_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_states_7 => permute_40
#   setitem_6 => index_put_6, select_56, select_57, view_45
#   view_10 => view_43
# Graph fragment:
#   %select_scatter_default_11 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_56 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 0), kwargs = {})
#   %select_57 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_56, 0, 3), kwargs = {})
#   %view_43 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_22, [1, 1, 2, 128]), kwargs = {})
#   %permute_40 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_43, [0, 2, 1, 3]), kwargs = {})
#   %view_45 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_40, [2, 1, 128]), kwargs = {})
#   %index_put_6 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_57, [None, None, %arg7_1], %view_45), kwargs = {})
#   return %index_put_6
# Hands the Triton source below to async_compile.triton under the kernel's name.
#
# What the embedded kernel does (1-D pointwise, xnumel fixed to 2097152):
#   * reads in_ptr0[6291456 + x0] (bf16 pointer) and writes it to out_ptr0[x0]
#     (bf16 pointer), i.e. a straight copy of a 2097152-element window that
#     starts 6291456 elements into in_ptr0.
# 6291456 == 3 * 2097152, which lines up with the "select ... 0, 3" on a
# stride-2097152 dimension in the graph fragment above.
# The kernel body only performs this copy; no indexed (scatter) write appears
# in it despite "index_put" being part of the kernel name.
# NOTE(review): the scatter is presumably applied to out_ptr0 by a later
# step in the wrapper -- confirm at the call site.
# xmask is computed but unused; the load and store pass None as the mask.
triton_poi_fused_index_put_select_transpose_view_25 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_25', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_25', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_25(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (6291456 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/rg/crgrqqqzqev73zhxb6sgd6wt7nhaifxkd6z5btm6nnuclcyt7kyz.py
# Topologically Sorted Source Nodes: [setitem_7, view_11, value_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_7 => index_put_7, select_61, select_62, view_46
#   value_states_7 => permute_41
#   view_11 => view_44
# Graph fragment:
#   %buf87 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf87]
#   %select_scatter_default_11 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_int_6 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 0), kwargs = {})
#   %select_scatter_default_12 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_6, %index_put_6, 0, 3), kwargs = {})
#   %select_scatter_default_13 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_11, %select_scatter_default_12, 0, 0), kwargs = {})
#   %select_61 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_13, 0, 1), kwargs = {})
#   %select_62 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_61, 0, 3), kwargs = {})
#   %view_44 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_23, [1, 1, 2, 128]), kwargs = {})
#   %permute_41 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_44, [0, 2, 1, 3]), kwargs = {})
#   %view_46 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_41, [2, 1, 128]), kwargs = {})
#   %index_put_7 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_62, [None, None, %arg7_1], %view_46), kwargs = {})
#   return %index_put_7
# Hands the Triton source below to async_compile.triton under the kernel's name.
#
# What the embedded kernel does (1-D pointwise, xnumel fixed to 2097152):
#   tmp5 = in_ptr0[x0]
#   tmp6 = in_ptr1[6291456 + x0]
#   tmp8 = in_ptr1[6291456 + ks0 + x0]
#   tmp7 = where(3 == 3, tmp5, tmp6)
#   tmp9 = where(1 == 0, tmp7, tmp8)   -> stored to out_ptr0[x0]
# Both tl.where conditions compare literal constants: tmp4 (3 == 3) is always
# true and tmp2 (1 == 0) is always false. The stored value is therefore
# always tmp8, i.e. in_ptr1[6291456 + ks0 + x0]; the tmp5 and tmp6 loads do not
# influence the output. The constant comparisons mirror the select /
# select_scatter index checks listed in the graph fragment above.
# ks0 is a runtime i64 offset supplied by the caller.
# NOTE(review): ks0 presumably equals the dim-0 stride of the cache tensor
# (2097152*s27) -- confirm against the call site.
# xmask is computed but unused; loads and the store pass None as the mask.
triton_poi_fused_index_put_select_transpose_view_26 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_26', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_26', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_26(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (6291456 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (6291456 + ks0 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 3, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/cm/ccm25fpon32mdisipw6gtdepgvqkla5nte4jmxwpakeolvy3ezso.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf90 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf90]
#   %buf87 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf87]
#   %select_scatter_default_11 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_11]
#   %select_int_6 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_11, 0, 0), kwargs = {})
#   %select_scatter_default_12 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_6, %index_put_6, 0, 3), kwargs = {})
#   %select_scatter_default_13 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_11, %select_scatter_default_12, 0, 0), kwargs = {})
#   %select_int_7 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_13, 0, 1), kwargs = {})
#   %select_scatter_default_14 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_7, %index_put_7, 0, 3), kwargs = {})
#   %select_scatter_default_15 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_13, %select_scatter_default_14, 0, 1), kwargs = {})
#   return %select_scatter_default_15
# Hands the Triton source below to async_compile.triton under the kernel's name.
#
# What the embedded kernel does (1-D pointwise; xnumel is NOT overwritten with
# a constant here, and ks0 / ks1 are runtime i64 values):
#   x2 = xindex // ks0                 x3 = xindex % ks0
#   x1 = (xindex // 2097152) % ks1     x0 = xindex % 2097152
# The chain of tl.where calls reduces to this selection, stored at
# out_ptr0[xindex]:
#   x2 == 1:  in_ptr0[x0] if x1 == 3 else in_ptr2[ks0 + x3]
#   x2 == 0:  in_ptr1[x0] if x1 == 3 else in_ptr2[x3]
#   else:     in_ptr2[xindex]
# tmp8 compares the literal constants 1 and 0, so it is always false and
# tmp13 always equals tmp12.
# NOTE(review): going by the placeholder order in the graph fragment above,
# in_ptr0 / in_ptr1 / in_ptr2 are presumably buf90 / buf87 /
# select_scatter_default_11, ks0 the dim-0 stride (2097152*s27) and ks1 the
# size s27 -- confirm against the call site.
# xmask is computed but unused; every load and the store pass None as the mask.
triton_poi_fused_27 = async_compile.triton('triton_poi_fused_27', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_27', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 268435456}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_27(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // ks0
    x1 = ((xindex // 2097152) % ks1)
    x0 = (xindex % 2097152)
    x3 = (xindex % ks0)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (ks0 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 3, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/fs/cfszqczajovb2me2fwi4ue35eompbujuivok3paa26ce6lidghgz.py
# Topologically Sorted Source Nodes: [attn_output_12], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_12 => clone_6, convert_element_type_84, expand_18, mul_48, permute_42, select_66, select_67, unsqueeze_6, view_47
# Graph fragment:
#   %select_scatter_default_15 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_66 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 0), kwargs = {})
#   %select_67 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_66, 0, 3), kwargs = {})
#   %convert_element_type_84 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_67, torch.float32), kwargs = {})
#   %unsqueeze_6 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_84, 2), kwargs = {})
#   %expand_18 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_6, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_6 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_18,), kwargs = {memory_format: torch.contiguous_format})
#   %view_47 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_6, [1, 16, 8192, 128]), kwargs = {})
#   %permute_42 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_47, [0, 1, 3, 2]), kwargs = {})
#   %mul_48 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_42, 0.29730177875068026), kwargs = {})
#   return %expand_21
# Hands the Triton source below to async_compile.triton under the kernel's name.
#
# What the embedded kernel does (2-D pointwise tile, ynumel fixed to 2048 and
# xnumel fixed to 8192; y comes from program_id(1), x from program_id(0)):
#   * y0 = yindex % 128, y1 = yindex // 128;
#   * reads bf16 from in_ptr0 at 6291456 + y0 + 128*x + 1048576*(y1 // 8);
#   * casts to float32 and multiplies by the constant 0.29730177875068026
#     (the scalar of the aten.mul.Scalar node in the graph fragment above);
#   * writes float32 to out_ptr0 at x + 8192*yindex.
# In the source y0 has unit stride and x has stride 128, whereas in the output
# x has unit stride, so the kernel writes a transposed layout. Because of
# (y1 // 8), eight consecutive y1 values read the same source slab.
# 6291456 == 3 * 2097152, which lines up with the "select ... 0, 3" on a
# stride-2097152 dimension in the graph fragment above.
# ymask / xmask are computed but unused; the load and store pass None as mask.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_28 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_28', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_28', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_28(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (6291456 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/7w/c7wcxndnvhlkn2te33ptpb5djpo3gsvhikxgcfdbhdlmfxjkpu3p.py
# Topologically Sorted Source Nodes: [setitem_7, attn_output_12], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_12 => clone_7, convert_element_type_85, expand_19, unsqueeze_7
#   setitem_7 => select_64, select_65
# Graph fragment:
#   %select_scatter_default_15 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_64 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 1), kwargs = {})
#   %select_65 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_64, 0, 3), kwargs = {})
#   %convert_element_type_85 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_65, torch.float32), kwargs = {})
#   %unsqueeze_7 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_85, 2), kwargs = {})
#   %expand_19 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_7, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_7 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_19,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_7
# Hands the Triton source below to async_compile.triton under the kernel's name.
#
# What the embedded kernel does (1-D pointwise, xnumel fixed to 16777216):
#   * reads bf16 from in_ptr0 at 6291456 + ks0 + x0 + 1048576*x2, where
#     x0 = xindex % 1048576 and x2 = xindex // 8388608;
#   * casts the value to float32 and writes it to out_ptr0 + xindex.
# The load address does not depend on (xindex // 1048576) % 8, so every
# 1048576-element source slab is written to 8 consecutive output slabs; this
# matches the expand-to-8 followed by a contiguous clone in the graph fragment above.
# 6291456 == 3 * 2097152, which lines up with the "select ... 0, 3" on a
# stride-2097152 dimension in the graph fragment above.
# ks0 is a runtime i64 offset supplied by the caller.
# NOTE(review): ks0 presumably equals the dim-0 stride of the cache tensor
# (2097152*s27, i.e. select index 1) -- confirm against the call site.
# xmask is computed but unused; loads and stores pass None as the mask.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_29 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_29', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_29', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_29(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (6291456 + ks0 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ph/cph4nxhroostqmhqevt7ltzzptgoafyqysqelkgyt4j3k6mbc76s.py
# Topologically Sorted Source Nodes: [hidden_states_11, hidden_states_14, hidden_states_16, hidden_states_19, to_16, pow_9, variance_8, add_16, rsqrt_8, mul_20, hidden_8, hidden_states_20], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_16 => add_93
#   hidden_8 => convert_element_type_101
#   hidden_states_11 => add_56
#   hidden_states_14 => add_59
#   hidden_states_16 => add_76
#   hidden_states_19 => add_79
#   hidden_states_20 => mul_58
#   mul_20 => mul_57
#   pow_9 => pow_9
#   rsqrt_8 => rsqrt_8
#   to_16 => convert_element_type_100
#   variance_8 => mean_8
# Graph fragment:
#   %add_39 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_39]
#   %mm_17 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_17]
#   %mm_20 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_20]
#   %mm_24 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_24]
#   %mm_27 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_27]
#   %add_79 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_79]
#   %buf110 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf110]
#   %arg40_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg40_1]
#   %add_56 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_39, %mm_17), kwargs = {})
#   %add_59 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_56, %mm_20), kwargs = {})
#   %add_76 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_59, %mm_24), kwargs = {})
#   %add_79 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_76, %mm_27), kwargs = {})
#   %convert_element_type_100 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_79, torch.float32), kwargs = {})
#   %pow_9 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_100, 2), kwargs = {})
#   %mean_8 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_9, [-1], True), kwargs = {})
#   %add_93 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_8, 1e-05), kwargs = {})
#   %rsqrt_8 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_93,), kwargs = {})
#   %mul_57 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_79, %rsqrt_8), kwargs = {})
#   %convert_element_type_101 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_57, torch.bfloat16), kwargs = {})
#   %mul_58 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_101, %arg40_1), kwargs = {})
#   return %add_79,%buf110,%mul_58
#
# Reviewer summary of the kernel text below. The kernel source is a runtime
# string literal and is kept byte-identical; only these comments are added.
#   * Pass 1 (first loop): for every R0_BLOCK-wide tile of the 2048-element
#     row, load one value each from in_out_ptr0 and in_ptr0..in_ptr3, widen
#     to float32, add the five values (tmp8), store the sum back to
#     in_out_ptr0 in place, and accumulate tmp8 * tmp8 into _tmp12 under
#     r0_mask.
#   * tmp12 = tl.sum(_tmp12, 1) is the sum of squares over the whole row.
#   * Pass 2 (second loop): re-read the summed row from in_out_ptr0, multiply
#     it by rsqrt(tmp12 / 2048.0 + 1e-05) and by the matching element of
#     in_ptr4, and store the product to out_ptr1.
# in_out_ptr0 appears in 'mutated_arg_names': the caller's buffer is
# overwritten with the five-way sum (presumably add_79 in the graph fragment
# above).
# xnumel and r0_numel are reassigned to the constants 1 and 2048 at the top of
# the kernel body, so the body works on those constants.
# NOTE(review): the graph fragment lists a bfloat16 conversion
# (convert_element_type_101) before the final multiply, while the kernel text
# has `.to(tl.float32)` at that point (tmp22) -- presumably the generator's
# float32 upcast of low-precision intermediates; confirm if exact bf16
# rounding matters.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_30 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_30', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_30', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 40960}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_30(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(r0_mask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp8, r0_mask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp14 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tl.full([1, 1], 2048.0, tl.float32)
        tmp17 = (tmp12 / tmp16)
        tmp18 = tl.full([1, 1], 1e-05, tl.float32)
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(out_ptr1 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp24, r0_mask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/yk/cykiomcenwrbagcjvxjuw6h4fcnajhupibsz2epp3fhpkourskwv.py
# Topologically Sorted Source Nodes: [setitem_8, view_13, key_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_states_9 => permute_52
#   setitem_8 => index_put_8, select_74, select_75, view_59
#   view_13 => view_57
# Graph fragment:
#   %select_scatter_default_15 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_74 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 0), kwargs = {})
#   %select_75 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_74, 0, 4), kwargs = {})
#   %view_57 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_29, [1, 1, 2, 128]), kwargs = {})
#   %permute_52 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_57, [0, 2, 1, 3]), kwargs = {})
#   %view_59 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_52, [2, 1, 128]), kwargs = {})
#   %index_put_8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_75, [None, None, %arg7_1], %view_59), kwargs = {})
#   return %index_put_8
#
# Reviewer summary of the kernel text below. The kernel source is a runtime
# string literal and is kept byte-identical; only these comments are added.
#   out_ptr0[i] = in_ptr0[8388608 + i]   for i in [0, 2097152)
# 8388608 == 4 * 2097152, which lines up with selecting index 4 on a
# 2097152-stride dimension in the graph fragment above (select_75).
# The loaded value is widened with .to(tl.float32) before the store; both
# pointers are declared '*bf16' in the signature.
# xnumel is reassigned to 2097152 and xmask is built but never applied: the
# load and the store both pass None as the mask.
# NOTE(review): the kernel name mentions index_put, but no indexed write
# appears in this kernel text -- presumably the scatter is performed by a
# separate step; confirm at the call site.
triton_poi_fused_index_put_select_transpose_view_31 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_31', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_31', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_31(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (8388608 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4x/c4xswot4jpvo3gw2bceywvz6c2jfahzfj42sxp5mf6bnf66qch7q.py
# Topologically Sorted Source Nodes: [setitem_9, view_14, value_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_9 => index_put_9, select_79, select_80, view_60
#   value_states_9 => permute_53
#   view_14 => view_58
# Graph fragment:
#   %buf115 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf115]
#   %select_scatter_default_15 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_int_8 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 0), kwargs = {})
#   %select_scatter_default_16 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_8, %index_put_8, 0, 4), kwargs = {})
#   %select_scatter_default_17 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_15, %select_scatter_default_16, 0, 0), kwargs = {})
#   %select_79 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_17, 0, 1), kwargs = {})
#   %select_80 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_79, 0, 4), kwargs = {})
#   %view_58 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_30, [1, 1, 2, 128]), kwargs = {})
#   %permute_53 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_58, [0, 2, 1, 3]), kwargs = {})
#   %view_60 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_53, [2, 1, 128]), kwargs = {})
#   %index_put_9 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_80, [None, None, %arg7_1], %view_60), kwargs = {})
#   return %index_put_9
#
# Reviewer summary of the kernel text below. The kernel source is a runtime
# string literal and is kept byte-identical; only these comments are added.
# Three loads per element i in [0, 2097152):
#   tmp5 = in_ptr0[i]
#   tmp6 = in_ptr1[8388608 + i]
#   tmp8 = in_ptr1[8388608 + ks0 + i]
# Both select predicates compare constants (tmp2 is 1 == 0, tmp4 is 4 == 4),
# so as written tmp9 always takes the tmp8 branch: the value stored to
# out_ptr0[i] is in_ptr1[8388608 + ks0 + i]. tmp5 and tmp6 only feed the
# branch that is never selected.
# ks0 is a runtime i64 offset; presumably the element stride of the leading
# size-2 dimension in the graph fragment above (2097152*s27) -- TODO confirm
# at the call site.
# xnumel is reassigned to 2097152 and xmask is built but never applied: the
# loads and the store all pass None as the mask.
triton_poi_fused_index_put_select_transpose_view_32 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_32', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_32', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_32(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (8388608 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (8388608 + ks0 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 4, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4a/c4a2pocuedlsnenhrtcr2wz5zcoe5opa4hcfyo5lrgzwucwsnwly.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf118 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf118]
#   %buf115 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf115]
#   %select_scatter_default_15 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_15]
#   %select_int_8 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_15, 0, 0), kwargs = {})
#   %select_scatter_default_16 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_8, %index_put_8, 0, 4), kwargs = {})
#   %select_scatter_default_17 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_15, %select_scatter_default_16, 0, 0), kwargs = {})
#   %select_int_9 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_17, 0, 1), kwargs = {})
#   %select_scatter_default_18 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_9, %index_put_9, 0, 4), kwargs = {})
#   %select_scatter_default_19 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_17, %select_scatter_default_18, 0, 1), kwargs = {})
#   return %select_scatter_default_19
#
# Reviewer summary of the kernel text below. The kernel source is a runtime
# string literal and is kept byte-identical; only these comments are added.
# Index decomposition of the flat index xindex:
#   x2 = xindex // ks0                x3 = xindex % ks0
#   x1 = (xindex // 2097152) % ks1    x0 = xindex % 2097152
# Value stored to out_ptr0[xindex], following the tl.where chain:
#   x2 == 1 and x1 == 4  -> in_ptr0[x0]
#   x2 == 1 and x1 != 4  -> in_ptr2[ks0 + x3]
#   x2 == 0 and x1 == 4  -> in_ptr1[x0]
#   x2 == 0 and x1 != 4  -> in_ptr2[x3]
#   any other x2         -> in_ptr2[xindex]
# tmp8 is the constant comparison 1 == 0, so tmp13 always equals tmp12.
# For x2 == 1, ks0 + x3 equals xindex, and for x2 == 0, x3 equals xindex, so
# every "x1 != 4" row reads in_ptr2 at the output position itself; only the
# two x1 == 4 slots are replaced by in_ptr1 / in_ptr0.
# ks0 and ks1 are runtime i64 values; presumably 2097152*s27 and s27 from the
# graph fragment above -- TODO confirm at the call site.
# xnumel is not reassigned to a constant here and is not referenced in the
# body; xmask is built but never applied (every load/store passes None).
triton_poi_fused_33 = async_compile.triton('triton_poi_fused_33', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_33', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 268435456}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_33(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // ks0
    x1 = ((xindex // 2097152) % ks1)
    x0 = (xindex % 2097152)
    x3 = (xindex % ks0)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (ks0 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 4, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/hd/chdscbdmhmbuelwkilmeuhz4wkbcfzcp3sx3bqmfto7cv26edfe7.py
# Topologically Sorted Source Nodes: [attn_output_16], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_16 => clone_8, convert_element_type_109, expand_24, mul_61, permute_54, select_84, select_85, unsqueeze_8, view_61
# Graph fragment:
#   %select_scatter_default_19 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_84 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 0), kwargs = {})
#   %select_85 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_84, 0, 4), kwargs = {})
#   %convert_element_type_109 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_85, torch.float32), kwargs = {})
#   %unsqueeze_8 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_109, 2), kwargs = {})
#   %expand_24 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_8, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_8 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_24,), kwargs = {memory_format: torch.contiguous_format})
#   %view_61 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_8, [1, 16, 8192, 128]), kwargs = {})
#   %permute_54 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_61, [0, 1, 3, 2]), kwargs = {})
#   %mul_61 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_54, 0.29730177875068026), kwargs = {})
#   return %expand_27
#
# Reviewer summary of the kernel text below. The kernel source is a runtime
# string literal and is kept byte-identical; only these comments are added.
# 2-D tiled kernel over y in [0, 2048) (program_id(1)) and x in [0, 8192)
# (program_id(0)), with y0 = y % 128 and y1 = y // 128:
#   out_ptr0[x + 8192*y] =
#       float32(in_ptr0[8388608 + y0 + 128*x + 1048576*(y1 // 8)])
#       * 0.29730177875068026
# Reads step through the source with stride 128 along x and stride 1 along
# y0, while writes are contiguous along x, so the two innermost axes are
# swapped on the way out (permute_54 in the graph fragment above).
# y1 // 8 sends each run of 8 consecutive y1 values to the same 1048576-element
# source slice (the expand/clone in the graph fragment above).
# in_ptr0 is declared '*bf16' and out_ptr0 '*fp32' in the signature.
# The scale constant is numerically close to 128 ** -0.25 -- presumably an
# attention scaling factor; confirm against the model code.
# ynumel/xnumel are reassigned to 2048/8192; ymask and xmask are built but
# never applied (the load and the store pass None as the mask).
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_34 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_34', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_34', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_34(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (8388608 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/f2/cf2i27zatsujymjfkivbcbx2ex4an2qlogqtvnvypvvyvyjdu5w4.py
# Topologically Sorted Source Nodes: [setitem_9, attn_output_16], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_16 => clone_9, convert_element_type_110, expand_25, unsqueeze_9
#   setitem_9 => select_82, select_83
# Graph fragment:
#   %select_scatter_default_19 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_82 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 1), kwargs = {})
#   %select_83 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_82, 0, 4), kwargs = {})
#   %convert_element_type_110 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_83, torch.float32), kwargs = {})
#   %unsqueeze_9 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_110, 2), kwargs = {})
#   %expand_25 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_9, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_9 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_25,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_9
#
# Reviewer summary of the kernel text below. The kernel source is a runtime
# string literal and is kept byte-identical; only these comments are added.
# For i in [0, 16777216), with x0 = i % 1048576 and x2 = i // 8388608:
#   out_ptr0[i] = float32(in_ptr0[8388608 + ks0 + x0 + 1048576*x2])
# Within each 8388608-element output group the source index depends only on
# x0, so every 1048576-element source slice is written 8 times in a row (the
# expand/clone in the graph fragment above).
# in_ptr0 is declared '*bf16' and out_ptr0 '*fp32' in the signature.
# ks0 is a runtime i64 offset; presumably the element stride of the leading
# size-2 dimension (2097152*s27 in the graph fragment above) -- TODO confirm
# at the call site.
# xnumel is reassigned to 16777216 and xmask is built but never applied: the
# load and the store both pass None as the mask.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_35 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_35', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_35', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_35(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (8388608 + ks0 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/uj/cujnj36gyhocsbiqbjd45vlnmhbgsbyh55qgjcebqz3g2lsdpruu.py
# Topologically Sorted Source Nodes: [setitem_10, view_16, key_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_states_11 => permute_64
#   setitem_10 => index_put_10, select_92, select_93, view_73
#   view_16 => view_71
# Graph fragment:
#   %select_scatter_default_19 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_92 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 0), kwargs = {})
#   %select_93 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_92, 0, 5), kwargs = {})
#   %view_71 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_36, [1, 1, 2, 128]), kwargs = {})
#   %permute_64 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_71, [0, 2, 1, 3]), kwargs = {})
#   %view_73 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_64, [2, 1, 128]), kwargs = {})
#   %index_put_10 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_93, [None, None, %arg7_1], %view_73), kwargs = {})
#   return %index_put_10
# Pointwise copy kernel (1D grid, 2097152 elements).
#   out_ptr0[x] = in_ptr0[10485760 + x]
# 10485760 == 5 * 2097152, which matches the `select(..., 0, 5)` slice of
# stride 2097152 in the graph fragment above. The value is widened to float32
# right after the load and written back through the bf16 output pointer.
# `xnumel` is overwritten with the constant 2097152 and `xmask` is computed but
# never used: every load/store passes `None` as its mask.
# NOTE(review): the scatter half of index_put is not visible in this kernel
# body — presumably it is applied to the copied buffer by a later step; confirm
# at the call site.
triton_poi_fused_index_put_select_transpose_view_36 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_36', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_36', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_36(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (10485760 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/7m/c7mc2lanvcsi7prwmigqvero5igop44cejsasy2segpwodgrbpkh.py
# Topologically Sorted Source Nodes: [setitem_11, view_17, value_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_11 => index_put_11, select_97, select_98, view_74
#   value_states_11 => permute_65
#   view_17 => view_72
# Graph fragment:
#   %buf142 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf142]
#   %select_scatter_default_19 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_int_10 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 0), kwargs = {})
#   %select_scatter_default_20 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_10, %index_put_10, 0, 5), kwargs = {})
#   %select_scatter_default_21 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_19, %select_scatter_default_20, 0, 0), kwargs = {})
#   %select_97 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_21, 0, 1), kwargs = {})
#   %select_98 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_97, 0, 5), kwargs = {})
#   %view_72 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_37, [1, 1, 2, 128]), kwargs = {})
#   %permute_65 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_72, [0, 2, 1, 3]), kwargs = {})
#   %view_74 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_65, [2, 1, 128]), kwargs = {})
#   %index_put_11 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_98, [None, None, %arg7_1], %view_74), kwargs = {})
#   return %index_put_11
# Pointwise kernel (1D grid, 2097152 elements) with three loads and one store.
# Both `tl.where` conditions are built from constants:
#   tmp2 = (1 == 0)  -> always False
#   tmp4 = (5 == 5)  -> always True
# so tmp7 reduces to in_ptr0[x] and the stored value reduces to
#   out_ptr0[x] = in_ptr1[10485760 + ks0 + x]
# The loads from in_ptr0[x] and in_ptr1[10485760 + x] are still issued but do
# not reach the output. 10485760 == 5 * 2097152, matching the `select(..., 0, 5)`
# slice in the graph fragment above.
# NOTE(review): ks0 is a runtime i64 added to the base offset — presumably the
# element stride between the two outermost slices of the source tensor
# (2097152*s27 in the graph fragment); confirm at the call site.
# `xnumel` is overwritten with 2097152 and `xmask` is unused; all accesses are
# unmasked.
triton_poi_fused_index_put_select_transpose_view_37 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_37', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_37', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_37(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (10485760 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (10485760 + ks0 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 5, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/w6/cw6yu7gbxejzh7bo5qkfdfnf7zlyt5su75cpdl65tiezhlppui6h.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf145 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf145]
#   %buf142 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf142]
#   %select_scatter_default_19 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_19]
#   %select_int_10 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_19, 0, 0), kwargs = {})
#   %select_scatter_default_20 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_10, %index_put_10, 0, 5), kwargs = {})
#   %select_scatter_default_21 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_19, %select_scatter_default_20, 0, 0), kwargs = {})
#   %select_int_11 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_21, 0, 1), kwargs = {})
#   %select_scatter_default_22 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_11, %index_put_11, 0, 5), kwargs = {})
#   %select_scatter_default_23 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_21, %select_scatter_default_22, 0, 1), kwargs = {})
#   return %select_scatter_default_23
# Pointwise kernel (1D grid) that merges two small buffers into a copy of a
# larger one, selecting per element on two decomposed indices:
#   x2 = xindex // ks0                  (compared against 0 and 1)
#   x1 = (xindex // 2097152) % ks1      (compared against 5)
#   x0 = xindex % 2097152               (offset inside one 2097152-element slab)
#   x3 = xindex % ks0
# Resulting value written to out_ptr0[xindex]:
#   x2 == 1 and x1 == 5  -> in_ptr0[x0]
#   x2 == 1 and x1 != 5  -> in_ptr2[ks0 + x3]
#   x2 == 0 and x1 == 5  -> in_ptr1[x0]
#   x2 == 0 and x1 != 5  -> in_ptr2[x3]
#   otherwise            -> in_ptr2[xindex]
# tmp8 = (1 == 0) is a constant False, so tmp13 is always the tmp12 load.
# Unlike the neighbouring kernels, `xnumel` is not overwritten and is never
# referenced in the body; `xmask` is unused and every load/store is unmasked.
# NOTE(review): ks0 / ks1 are runtime i64 values — presumably the outer-slice
# stride (2097152*s27) and s27 from the graph fragment above; confirm at the
# call site.
triton_poi_fused_38 = async_compile.triton('triton_poi_fused_38', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_38', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 268435456}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_38(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // ks0
    x1 = ((xindex // 2097152) % ks1)
    x0 = (xindex % 2097152)
    x3 = (xindex % ks0)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (ks0 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 5, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/5f/c5f4733fz6ktkdd3t2xc2neeuvr72f5oc4y6qyqbzkgquz6zcnmo.py
# Topologically Sorted Source Nodes: [attn_output_20], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_20 => clone_10, convert_element_type_134, expand_30, mul_74, permute_66, select_102, select_103, unsqueeze_10, view_75
# Graph fragment:
#   %select_scatter_default_23 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_102 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 0), kwargs = {})
#   %select_103 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_102, 0, 5), kwargs = {})
#   %convert_element_type_134 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_103, torch.float32), kwargs = {})
#   %unsqueeze_10 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_134, 2), kwargs = {})
#   %expand_30 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_10, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_10 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_30,), kwargs = {memory_format: torch.contiguous_format})
#   %view_75 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_10, [1, 16, 8192, 128]), kwargs = {})
#   %permute_66 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_75, [0, 1, 3, 2]), kwargs = {})
#   %mul_74 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_66, 0.29730177875068026), kwargs = {})
#   return %expand_33
# 2D pointwise kernel (Grid2D, y: 2048, x: 8192) that reads bf16, scales, and
# writes fp32 in a transposed layout:
#   y0 = yindex % 128, y1 = yindex // 128
#   out_ptr0[x + 8192*yindex] =
#       float32(in_ptr0[10485760 + y0 + 128*x + 1048576*(y1 // 8)])
#       * 0.29730177875068026
# The `y1 // 8` term makes eight consecutive y1 values read the same
# 1048576-element source slab, which is how the expand-to-8 in the graph
# fragment above is materialised. 10485760 == 5 * 2097152, matching the
# `select(..., 0, 5)` slice of stride 2097152.
# `ynumel`/`xnumel` are overwritten with constants and `ymask`/`xmask` are
# unused; the load and store are unmasked.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_39 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_39', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_39', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_39(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (10485760 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/n5/cn5twycnrqwdi7x2s7s6hywxe6wai4w7wspxt2hdi25ozwzvet62.py
# Topologically Sorted Source Nodes: [setitem_11, attn_output_20], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_20 => clone_11, convert_element_type_135, expand_31, unsqueeze_11
#   setitem_11 => select_100, select_101
# Graph fragment:
#   %select_scatter_default_23 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_100 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 1), kwargs = {})
#   %select_101 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_100, 0, 5), kwargs = {})
#   %convert_element_type_135 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_101, torch.float32), kwargs = {})
#   %unsqueeze_11 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_135, 2), kwargs = {})
#   %expand_31 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_11, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_11 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_31,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_11
# Pointwise kernel (1D grid, 16777216 elements) that converts bf16 to fp32
# while repeating each source slab eight times:
#   x0 = xindex % 1048576, x2 = xindex // 8388608
#   out_ptr0[xindex] = float32(in_ptr0[10485760 + ks0 + x0 + 1048576*x2])
# Since 8388608 == 8 * 1048576, every run of 8388608 output elements re-reads
# the same 1048576-element source slab eight times (the expand-to-8 followed by
# a contiguous clone in the graph fragment above).
# NOTE(review): ks0 is a runtime i64 added to the base offset — presumably the
# stride that selects outer index 1 of the source tensor (2097152*s27 in the
# graph fragment); confirm at the call site.
# `xnumel` is overwritten with 16777216 and `xmask` is unused; accesses are
# unmasked.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_40 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_40', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_40', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_40(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (10485760 + ks0 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/bc/cbcqdjmmfbzl62w5cf7r3aa4vnilmma7bnstctuwux62dlolmoes.py
# Topologically Sorted Source Nodes: [setitem_12, view_19, key_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_states_13 => permute_76
#   setitem_12 => index_put_12, select_110, select_111, view_87
#   view_19 => view_85
# Graph fragment:
#   %select_scatter_default_23 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_110 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 0), kwargs = {})
#   %select_111 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_110, 0, 6), kwargs = {})
#   %view_85 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_43, [1, 1, 2, 128]), kwargs = {})
#   %permute_76 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_85, [0, 2, 1, 3]), kwargs = {})
#   %view_87 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_76, [2, 1, 128]), kwargs = {})
#   %index_put_12 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_111, [None, None, %arg7_1], %view_87), kwargs = {})
#   return %index_put_12
# Pointwise copy kernel (1D grid, 2097152 elements).
#   out_ptr0[x] = in_ptr0[12582912 + x]
# 12582912 == 6 * 2097152, which matches the `select(..., 0, 6)` slice of
# stride 2097152 in the graph fragment above. The value is widened to float32
# right after the load and written back through the bf16 output pointer.
# `xnumel` is overwritten with the constant 2097152 and `xmask` is computed but
# never used: every load/store passes `None` as its mask.
# NOTE(review): the scatter half of index_put is not visible in this kernel
# body — presumably it is applied to the copied buffer by a later step; confirm
# at the call site.
triton_poi_fused_index_put_select_transpose_view_41 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_41', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_41', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_41(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (12582912 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ec/ceckbtqpqkmt5k3vvizdnpf2f6vd3df4zyo7cgkondyu2zxvwk44.py
# Topologically Sorted Source Nodes: [setitem_13, view_20, value_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_13 => index_put_13, select_115, select_116, view_88
#   value_states_13 => permute_77
#   view_20 => view_86
# Graph fragment:
#   %buf170 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf170]
#   %select_scatter_default_23 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_int_12 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 0), kwargs = {})
#   %select_scatter_default_24 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_12, %index_put_12, 0, 6), kwargs = {})
#   %select_scatter_default_25 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_23, %select_scatter_default_24, 0, 0), kwargs = {})
#   %select_115 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_25, 0, 1), kwargs = {})
#   %select_116 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_115, 0, 6), kwargs = {})
#   %view_86 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_44, [1, 1, 2, 128]), kwargs = {})
#   %permute_77 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_86, [0, 2, 1, 3]), kwargs = {})
#   %view_88 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_77, [2, 1, 128]), kwargs = {})
#   %index_put_13 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_116, [None, None, %arg7_1], %view_88), kwargs = {})
#   return %index_put_13
# Pointwise kernel (1D grid, 2097152 elements) with three loads and one store.
# Both `tl.where` conditions are built from constants:
#   tmp2 = (1 == 0)  -> always False
#   tmp4 = (6 == 6)  -> always True
# so tmp7 reduces to in_ptr0[x] and the stored value reduces to
#   out_ptr0[x] = in_ptr1[12582912 + ks0 + x]
# The loads from in_ptr0[x] and in_ptr1[12582912 + x] are still issued but do
# not reach the output. 12582912 == 6 * 2097152, matching the `select(..., 0, 6)`
# slice in the graph fragment above.
# NOTE(review): ks0 is a runtime i64 added to the base offset — presumably the
# element stride between the two outermost slices of the source tensor
# (2097152*s27 in the graph fragment); confirm at the call site.
# `xnumel` is overwritten with 2097152 and `xmask` is unused; all accesses are
# unmasked.
triton_poi_fused_index_put_select_transpose_view_42 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_42', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_42', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_42(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (12582912 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (12582912 + ks0 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 6, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/si/csipylghymxrlrjeapqwt2snrq2a22u6y6pksejsxvvddgmoqbpe.py
# Topologically Sorted Source Nodes: [], Original ATen: []
# Source node to ATen node mapping:
# Graph fragment:
#   %buf173 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf173]
#   %buf170 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf170]
#   %select_scatter_default_23 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_23]
#   %select_int_12 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_23, 0, 0), kwargs = {})
#   %select_scatter_default_24 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_12, %index_put_12, 0, 6), kwargs = {})
#   %select_scatter_default_25 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_23, %select_scatter_default_24, 0, 0), kwargs = {})
#   %select_int_13 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_25, 0, 1), kwargs = {})
#   %select_scatter_default_26 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_13, %index_put_13, 0, 6), kwargs = {})
#   %select_scatter_default_27 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=6] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_25, %select_scatter_default_26, 0, 1), kwargs = {})
#   return %select_scatter_default_27
# Pointwise Triton kernel, registered for asynchronous compilation, that
# rebuilds a full output buffer element by element. Unlike the neighbouring
# kernels, xnumel is not overwritten with a constant in the body (and is not
# otherwise used there).
# The flat index is split as:
#   x2 = xindex // ks0                (outer slice index)
#   x1 = (xindex // 2097152) % ks1    (index of the 2097152-element chunk)
#   x0 = xindex % 2097152             (offset inside that chunk)
#   x3 = xindex % ks0                 (offset inside the outer slice)
# tmp8 is the constant comparison `1 == 0`, so tmp13 always equals tmp12.
# The effective selection written to out_ptr0[xindex] is:
#   x2 == 1 and x1 == 6  -> in_ptr0[x0]
#   x2 == 1 and x1 != 6  -> in_ptr2[ks0 + x3]
#   x2 == 0 and x1 == 6  -> in_ptr1[x0]
#   x2 == 0 and x1 != 6  -> in_ptr2[x3]
#   any other x2         -> in_ptr2[xindex]
# All five loads are issued unconditionally and unmasked; xmask is unused.
# NOTE(review): ks0 / ks1 are presumably 2097152*s27 and s27 from the graph
# fragment above -- confirm against the call site.
triton_poi_fused_43 = async_compile.triton('triton_poi_fused_43', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_43', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 268435456}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_43(in_ptr0, in_ptr1, in_ptr2, out_ptr0, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // ks0
    x1 = ((xindex // 2097152) % ks1)
    x0 = (xindex % 2097152)
    x3 = (xindex % ks0)
    x4 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (ks0 + x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 6, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x4), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/bt/cbtdtwj3yaiggs2sxtbge7u4uygbkhdqxzqj67epyr4474rrqskv.py
# Topologically Sorted Source Nodes: [attn_output_24], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_24 => clone_12, convert_element_type_159, expand_36, mul_87, permute_78, select_120, select_121, unsqueeze_12, view_89
# Graph fragment:
#   %select_scatter_default_27 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_120 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 0), kwargs = {})
#   %select_121 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_120, 0, 6), kwargs = {})
#   %convert_element_type_159 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_121, torch.float32), kwargs = {})
#   %unsqueeze_12 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_159, 2), kwargs = {})
#   %expand_36 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_12, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_12 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_36,), kwargs = {memory_format: torch.contiguous_format})
#   %view_89 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_12, [1, 16, 8192, 128]), kwargs = {})
#   %permute_78 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_89, [0, 1, 3, 2]), kwargs = {})
#   %mul_87 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_78, 0.29730177875068026), kwargs = {})
#   return %expand_39
# 2D-tiled pointwise Triton kernel, registered for asynchronous compilation.
# The y axis (program_id(1), hard-coded ynumel = 2048) and the x axis
# (program_id(0), hard-coded xnumel = 8192) form a [YBLOCK, XBLOCK] tile.
# yindex is split into y0 = yindex % 128 and y1 = yindex // 128.
# Each element reads bf16 from
#   in_ptr0[12582912 + y0 + 128*x2 + 1048576*(y1 // 8)]
# (12582912 == 6 * 2097152), upcasts it to fp32, multiplies by the constant
# 0.29730177875068026 and writes fp32 to out_ptr0[x2 + 8192*yindex].
# Because the source address uses `y1 // 8`, eight consecutive y1 values read
# the same 1048576-element source chunk. The read is strided by 128 along x
# while the write is contiguous along x, i.e. the kernel transposes while
# copying. The second `.to(tl.float32)` (tmp1) is applied to a value that was
# already converted to fp32. Loads/stores are unmasked; xmask/ymask are unused.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_44 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_44', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_44', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_44(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (12582912 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/yt/cyt7zc6lvbddvvk3lg76uulfeb3kb6zrvmoserbmzaxeim4x3sm5.py
# Topologically Sorted Source Nodes: [setitem_13, attn_output_24], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_24 => clone_13, convert_element_type_160, expand_37, unsqueeze_13
#   setitem_13 => select_118, select_119
# Graph fragment:
#   %select_scatter_default_27 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_118 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 1), kwargs = {})
#   %select_119 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_118, 0, 6), kwargs = {})
#   %convert_element_type_160 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_119, torch.float32), kwargs = {})
#   %unsqueeze_13 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_160, 2), kwargs = {})
#   %expand_37 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_13, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_13 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_37,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_13
# Pointwise Triton kernel, registered for asynchronous compilation, that
# fills 16777216 contiguous fp32 outputs (xnumel is hard-coded in the body).
# The flat index is split into x0 = xindex % 1048576 and
# x2 = xindex // 8388608; each element reads bf16 from
#   in_ptr0[12582912 + ks0 + x0 + 1048576*x2]
# (12582912 == 6 * 2097152), upcasts to fp32 and stores to out_ptr0[xindex].
# The factor of xindex between 1048576 and 8388608 (8 values) does not enter
# the source address, so every 1048576-element source chunk is written out
# 8 times in a row. The second `.to(tl.float32)` (tmp1) is applied to a value
# that was already converted to fp32. Loads/stores are unmasked; xmask is
# computed but unused.
# NOTE(review): ks0 is presumably the element stride of the outermost
# dimension of the tensor behind in_ptr0 -- confirm against the call site.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_45 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_45', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_45', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_45(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (12582912 + ks0 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/pp/cppnnxka25w3y4kwsfgefmjpbck2drvrbhexliz7xkrsxahwyi3b.py
# Topologically Sorted Source Nodes: [setitem_14, view_22, key_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   key_states_15 => permute_88
#   setitem_14 => index_put_14, select_128, select_129, view_101
#   view_22 => view_99
# Graph fragment:
#   %select_scatter_default_27 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_128 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 0), kwargs = {})
#   %select_129 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_128, 0, 7), kwargs = {})
#   %view_99 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_50, [1, 1, 2, 128]), kwargs = {})
#   %permute_88 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_99, [0, 2, 1, 3]), kwargs = {})
#   %view_101 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_88, [2, 1, 128]), kwargs = {})
#   %index_put_14 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_129, [None, None, %arg7_1], %view_101), kwargs = {})
#   return %index_put_14
# Pointwise Triton kernel, registered for asynchronous compilation, that
# copies a contiguous run of 2097152 elements (xnumel is hard-coded in the
# body): out_ptr0[x] = in_ptr0[14680064 + x], where 14680064 == 7 * 2097152.
# The value is upcast to fp32 after the load and stored through out_ptr0,
# which the signature declares as '*bf16'. The load and store are unmasked;
# xmask is computed but unused.
triton_poi_fused_index_put_select_transpose_view_46 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_46', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_46', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12582912}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_46(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (14680064 + x0), None).to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/zu/czuceabomutjb43hcgcsn2npgspt55z2fe6eduuzj5xi5vn3nh3s.py
# Topologically Sorted Source Nodes: [setitem_15, view_23, value_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
# Source node to ATen node mapping:
#   setitem_15 => index_put_15, select_133, select_134, view_102
#   value_states_15 => permute_89
#   view_23 => view_100
# Graph fragment:
#   %buf197 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf197]
#   %select_scatter_default_27 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_int_14 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 0), kwargs = {})
#   %select_scatter_default_28 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_14, %index_put_14, 0, 7), kwargs = {})
#   %select_scatter_default_29 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_27, %select_scatter_default_28, 0, 0), kwargs = {})
#   %select_133 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_29, 0, 1), kwargs = {})
#   %select_134 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_133, 0, 7), kwargs = {})
#   %view_100 : Tensor "bf16[1, 1, 2, 128][256, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_51, [1, 1, 2, 128]), kwargs = {})
#   %permute_89 : Tensor "bf16[1, 2, 1, 128][256, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_100, [0, 2, 1, 3]), kwargs = {})
#   %view_102 : Tensor "bf16[2, 1, 128][128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_89, [2, 1, 128]), kwargs = {})
#   %index_put_15 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%select_134, [None, None, %arg7_1], %view_102), kwargs = {})
#   return %index_put_15
# Pointwise Triton kernel, registered for asynchronous compilation, that walks
# a flat range of 2097152 elements (xnumel is hard-coded in the kernel body).
# For each element x it loads three bf16 values and upcasts them to fp32:
#   in_ptr0[x], in_ptr1[14680064 + x], in_ptr1[14680064 + ks0 + x]
# where 14680064 == 7 * 2097152.
# Both predicates are constants: tmp2 is `1 == 0` (always false) and tmp4 is
# `7 == 7` (always true). The nested tl.where therefore always yields tmp8,
# i.e. the value read from in_ptr1[14680064 + ks0 + x], which is what gets
# stored to out_ptr0[x] (declared '*bf16' in the signature).
# All loads and the store pass mask=None; xmask is computed but never used.
# NOTE(review): ks0 is presumably 2097152*s27, the outermost stride shown in
# the graph fragment above -- confirm against the call site.
triton_poi_fused_index_put_select_transpose_view_47 = async_compile.triton('triton_poi_fused_index_put_select_transpose_view_47', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2097152}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_index_put_select_transpose_view_47', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 16777216}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_index_put_select_transpose_view_47(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2097152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp6 = tl.load(in_ptr1 + (14680064 + x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (14680064 + ks0 + x0), None).to(tl.float32)
    tmp0 = tl.full([1], 1, tl.int32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = tl.full([1], 7, tl.int32)
    tmp4 = tmp3 == tmp3
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp9 = tl.where(tmp2, tmp7, tmp8)
    tl.store(out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/7n/c7nigcu2n4wzciqjy32wn6fgy7366gqt3hrgctbjiggi3fqwskbq.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.copy_]
# Source node to ATen node mapping:
# Graph fragment:
#   %buf200 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf200]
#   %buf197 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf197]
#   %select_scatter_default_27 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_27]
#   %select_scatter_default_31 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_31]
#   %copy_ : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=copy_]
#   %buf5 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf5]
#   %buf8 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0" = PlaceHolder[target=buf8]
#   %select_scatter_default_3 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_3]
#   %select_int_14 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_27, 0, 0), kwargs = {})
#   %select_scatter_default_28 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_14, %index_put_14, 0, 7), kwargs = {})
#   %select_scatter_default_29 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=4] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_27, %select_scatter_default_28, 0, 0), kwargs = {})
#   %select_int_15 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_29, 0, 1), kwargs = {})
#   %select_scatter_default_30 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_int_15, %index_put_15, 0, 7), kwargs = {})
#   %select_scatter_default_31 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_default_29, %select_scatter_default_30, 0, 1), kwargs = {})
#   %copy_ : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %select_scatter_default_31), kwargs = {})
#   return %select_scatter_default_31,%buf225
# Pointwise Triton kernel, registered for asynchronous compilation, that
# rebuilds a full buffer element by element and writes the same result to two
# destinations: out_ptr0[xindex] and out_ptr1[xindex]. inductor_meta lists
# 'out_ptr1' under mutated_arg_names. xnumel is not overwritten with a
# constant in the body (and is not otherwise used there).
# The flat index is split as:
#   x2 = xindex // ks0                (outer slice index)
#   x1 = (xindex // 2097152) % ks1    (index of the 2097152-element chunk)
#   x0 = xindex % 2097152             (offset inside that chunk)
#   x4 = xindex % ks0                 (offset inside the outer slice)
# tmp8 is the constant comparison `1 == 0`, so tmp13 always equals tmp12.
# The effective selection is:
#   x2 == 1 and x1 == 7  -> in_ptr0[x0]
#   x2 == 1 and x1 != 7  -> in_ptr2[ks0 + x4]
#   x2 == 0 and x1 == 7  -> in_ptr1[x0]
#   x2 == 0 and x1 != 7  -> in_ptr2[x4]
#   any other x2         -> in_ptr2[xindex]
# All five loads are issued unconditionally and unmasked; xmask is unused.
# NOTE(review): ks0 / ks1 are presumably 2097152*s27 and s27 from the graph
# fragment above, and out_ptr1 presumably aliases the graph input arg1_1
# targeted by the copy_ node -- confirm against the call site.
triton_poi_fused_copy__48 = async_compile.triton('triton_poi_fused_copy__48', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 33554432}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'out_ptr1': '*bf16', 'ks0': 'i64', 'ks1': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_copy__48', 'mutated_arg_names': ['out_ptr1'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 2, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 469762048}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_copy__48(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, ks0, ks1, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // ks0
    x1 = ((xindex // 2097152) % ks1)
    x0 = (xindex % 2097152)
    x4 = (xindex % ks0)
    x3 = xindex
    tmp6 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp9 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (ks0 + x4), None, eviction_policy='evict_last').to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (x3), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x2
    tmp1 = tl.full([1], 1, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp3 = x1
    tmp4 = tl.full([1], 7, tl.int32)
    tmp5 = tmp3 == tmp4
    tmp7 = tl.full([1], 0, tl.int32)
    tmp8 = tmp1 == tmp7
    tmp11 = tl.where(tmp5, tmp9, tmp10)
    tmp13 = tl.where(tmp8, tmp11, tmp12)
    tmp14 = tl.where(tmp5, tmp6, tmp13)
    tmp15 = tmp0 == tmp7
    tmp17 = tl.where(tmp15, tmp11, tmp16)
    tmp18 = tl.where(tmp2, tmp14, tmp17)
    tl.store(out_ptr0 + (x3), tmp18, None)
    tl.store(out_ptr1 + (x3), tmp18, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ak/cakvv3pt67tqmh7nnoxia7mcm2qgazwzhzoryy4uhgjq34oejcww.py
# Topologically Sorted Source Nodes: [attn_output_28], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
# Source node to ATen node mapping:
#   attn_output_28 => clone_14, convert_element_type_184, expand_42, mul_100, permute_90, select_138, select_139, unsqueeze_14, view_103
# Graph fragment:
#   %select_scatter_default_31 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_31]
#   %select_138 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_31, 0, 0), kwargs = {})
#   %select_139 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_138, 0, 7), kwargs = {})
#   %convert_element_type_184 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_139, torch.float32), kwargs = {})
#   %unsqueeze_14 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_184, 2), kwargs = {})
#   %expand_42 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_14, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_14 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_42,), kwargs = {memory_format: torch.contiguous_format})
#   %view_103 : Tensor "f32[1, 16, 8192, 128][16777216, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%clone_14, [1, 16, 8192, 128]), kwargs = {})
#   %permute_90 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_103, [0, 1, 3, 2]), kwargs = {})
#   %mul_100 : Tensor "f32[1, 16, 128, 8192][16777216, 1048576, 1, 128]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%permute_90, 0.29730177875068026), kwargs = {})
#   return %expand_45
# Registers a 2D pointwise Triton kernel; the kernel source is the string literal
# passed to async_compile.triton and is left untouched here.
# What the kernel source below does:
#   * The ynumel / xnumel arguments are overwritten with the constants 2048 and 8192.
#   * yindex is split into y0 = yindex % 128 and y1 = yindex // 128; xindex is used
#     directly as x2.
#   * Each element is loaded from in_ptr0 ('*bf16' in the signature) at
#     14680064 + y0 + 128*x2 + 1048576*(y1 // 8), cast to fp32, multiplied by
#     0.29730177875068026, and stored to out_ptr0 ('*fp32') at x2 + 8192*yindex.
#   * 14680064 == 7 * 2097152, which lines up with the index-7 select over a
#     stride-2097152 dimension shown in the graph fragment above.
#   * `y1 // 8` makes every run of 8 consecutive y1 values read the same
#     1048576-element slice, which lines up with the expand-to-8 in the graph fragment.
#   * Loads and stores are issued with mask None; ymask / xmask are built but unused.
#   * tmp1 re-casts an already-fp32 value, so that cast does not change the data.
triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_49 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_49', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 2048, 'x': 8192}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_49', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 33554432, 'x': 134217728}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_49(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 2048
    xnumel = 8192
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = tl.full([YBLOCK], True, tl.int1)[:, None]
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = tl.full([XBLOCK], True, tl.int1)[None, :]
    x2 = xindex
    y0 = (yindex % 128)
    y1 = yindex // 128
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (14680064 + y0 + 128*x2 + 1048576*(y1 // 8)), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.full([1, 1], 0.29730177875068026, tl.float32)
    tmp3 = tmp1 * tmp2
    tl.store(out_ptr0 + (x2 + 8192*y3), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/rg/crgjkwbesjduxmnoscqtcud46sbybvwfopvpczpsm6yn67er633o.py
# Topologically Sorted Source Nodes: [setitem_15, attn_output_28], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
# Source node to ATen node mapping:
#   attn_output_28 => clone_15, convert_element_type_185, expand_43, unsqueeze_15
#   setitem_15 => select_136, select_137
# Graph fragment:
#   %select_scatter_default_31 : Tensor "bf16[2, s27, 1, 2, 8192, 128][2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1]cuda:0" = PlaceHolder[target=select_scatter_default_31]
#   %select_136 : Tensor "bf16[s27, 1, 2, 8192, 128][2097152, 2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_default_31, 0, 1), kwargs = {})
#   %select_137 : Tensor "bf16[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_136, 0, 7), kwargs = {})
#   %convert_element_type_185 : Tensor "f32[1, 2, 8192, 128][2097152, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%select_137, torch.float32), kwargs = {})
#   %unsqueeze_15 : Tensor "f32[1, 2, 1, 8192, 128][2097152, 1048576, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%convert_element_type_185, 2), kwargs = {})
#   %expand_43 : Tensor "f32[1, 2, 8, 8192, 128][2097152, 1048576, 0, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%unsqueeze_15, [1, 2, 8, 8192, 128]), kwargs = {})
#   %clone_15 : Tensor "f32[1, 2, 8, 8192, 128][16777216, 8388608, 1048576, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%expand_43,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_15
# Registers a 1D pointwise Triton kernel; the kernel source is the string literal
# passed to async_compile.triton and is left untouched here.
# What the kernel source below does:
#   * The xnumel argument is overwritten with the constant 16777216.
#   * x0 = xindex % 1048576 and x2 = xindex // 8388608; the source offset is
#     14680064 + ks0 + x0 + 1048576*x2, so every run of 8 consecutive 1048576-element
#     output chunks re-reads the same input chunk (the expand-to-8 in the graph fragment).
#   * The value is loaded from in_ptr0 ('*bf16'), cast to fp32 and stored to
#     out_ptr0 ('*fp32') at the flat index xindex; no scaling is applied.
#   * 14680064 == 7 * 2097152, which lines up with the index-7 select in the graph
#     fragment above.
#   * ks0 is a runtime i64 added to the load offset.
#     NOTE(review): presumably ks0 is the dim-0 stride 2097152*s27 that implements the
#     index-1 select in the graph fragment — confirm against the call site.
#   * Load and store are issued with mask None; xmask is built but unused.
triton_poi_fused__to_copy_clone_expand_select_unsqueeze_50 = async_compile.triton('triton_poi_fused__to_copy_clone_expand_select_unsqueeze_50', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 16777216}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i64', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clone_expand_select_unsqueeze_50', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 138412032}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_clone_expand_select_unsqueeze_50(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16777216
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 1048576)
    x2 = xindex // 8388608
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (14680064 + ks0 + x0 + 1048576*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ms/cmsn3yxcso5j6cszodnqx626nawkbqrlslluenjfku5litdtmmea.py
# Topologically Sorted Source Nodes: [hidden_states_31, hidden_states_34, hidden_states_36, hidden_states_39, to_32, pow_17, variance_16, add_32, rsqrt_16, mul_40, hidden_16, hidden_states_40], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_32 => add_160
#   hidden_16 => convert_element_type_201
#   hidden_states_31 => add_136
#   hidden_states_34 => add_139
#   hidden_states_36 => add_156
#   hidden_states_39 => add_159
#   hidden_states_40 => mul_105
#   mul_40 => mul_104
#   pow_17 => pow_17
#   rsqrt_16 => rsqrt_16
#   to_32 => convert_element_type_200
#   variance_16 => mean_16
# Graph fragment:
#   %add_119 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_119]
#   %mm_45 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_45]
#   %mm_48 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_48]
#   %mm_52 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_52]
#   %mm_55 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm_55]
#   %add_159 : Tensor "bf16[1, 2048][2048, 1]cuda:0" = PlaceHolder[target=add_159]
#   %buf220 : Tensor "f32[1, 1][1, 1]cuda:0" = PlaceHolder[target=buf220]
#   %arg76_1 : Tensor "bf16[2048][1]cuda:0" = PlaceHolder[target=arg76_1]
#   %add_136 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_119, %mm_45), kwargs = {})
#   %add_139 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_136, %mm_48), kwargs = {})
#   %add_156 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_139, %mm_52), kwargs = {})
#   %add_159 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_156, %mm_55), kwargs = {})
#   %convert_element_type_200 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_159, torch.float32), kwargs = {})
#   %pow_17 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_200, 2), kwargs = {})
#   %mean_16 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_17, [-1], True), kwargs = {})
#   %add_160 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_16, 1e-05), kwargs = {})
#   %rsqrt_16 : Tensor "f32[1, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_160,), kwargs = {})
#   %mul_104 : Tensor "f32[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_159, %rsqrt_16), kwargs = {})
#   %convert_element_type_201 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_104, torch.bfloat16), kwargs = {})
#   %mul_105 : Tensor "bf16[1, 2048][2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_201, %arg76_1), kwargs = {})
#   return %add_159,%buf220,%mul_105
# Registers a two-pass Triton reduction kernel; the kernel source is the string literal
# passed to async_compile.triton and is left untouched here.
# What the kernel source below does:
#   * The xnumel / r0_numel arguments are overwritten with the constants 1 and 2048.
#   * Pass 1 (first loop over R0_BLOCK-sized chunks): loads in_out_ptr0 and
#     in_ptr0..in_ptr3 (each cast to fp32), adds them in that order, accumulates the
#     square of the sum into _tmp12 (kept unchanged where r0_mask is false), and stores
#     the sum back to in_out_ptr0.
#   * tmp12 is the sum of _tmp12 along axis 1.
#   * Pass 2 (second loop): reloads in_out_ptr0, multiplies by
#     rsqrt(tmp12 / 2048.0 + 1e-05), multiplies by in_ptr4, and stores the result to
#     in_out_ptr0 again, overwriting the value written in pass 1.
#   * in_out_ptr0 is '*bf16' in the signature, so pass 2 reads back what was stored
#     through that bf16 pointer rather than the fp32 value held in pass 1.
#   * NOTE(review): the graph fragment above lists add_159 and mul_105 as separate
#     returns, yet both stores target in_out_ptr0 — presumably deliberate buffer reuse;
#     confirm at the call site that the pass-1 sum is not needed afterwards.
#   * rnumel, RBLOCK, rbase, roffset, rindex, xindex and xmask are assigned but unused.
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_51 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_51', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 2048},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_51', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 32768}}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_51(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 2048
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK], True, tl.int1)[:, None]
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp3 = tl.load(in_ptr1 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp5 = tl.load(in_ptr2 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp7 = tl.load(in_ptr3 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp4 = tmp2 + tmp3
        tmp6 = tmp4 + tmp5
        tmp8 = tmp6 + tmp7
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(r0_mask, tmp13, _tmp12)
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp8, r0_mask)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp14 = tl.load(in_out_ptr0 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp23 = tl.load(in_ptr4 + (r0_0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tl.full([1, 1], 2048.0, tl.float32)
        tmp17 = (tmp12 / tmp16)
        tmp18 = tl.full([1, 1], 1e-05, tl.float32)
        tmp19 = tmp17 + tmp18
        tmp20 = libdevice.rsqrt(tmp19)
        tmp21 = tmp15 * tmp20
        tmp22 = tmp21.to(tl.float32)
        tmp24 = tmp22 * tmp23
        tl.store(in_out_ptr0 + (tl.broadcast_to(r0_0, [XBLOCK, R0_BLOCK])), tmp24, r0_mask)
''', device_str='cuda')

def partition_0(args):
    arg3_1, arg2_1, arg4_1, arg5_1, arg1_1, arg7_1, arg6_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, s27 = args
    args.clear()
    s27 = s27
    assert_size_stride(arg3_1, (1, 2048), (2048, 1))
    assert_size_stride(arg2_1, (2048, ), (1, ))
    assert_size_stride(arg4_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg5_1, (256, 2048), (2048, 1))
    assert_size_stride(arg1_1, (2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 2097152, 1048576, 128, 1))
    assert_size_stride(arg7_1, (1, ), (1, ))
    assert_size_stride(arg6_1, (256, 2048), (2048, 1))
    assert_size_stride(arg8_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg9_1, (2048, ), (1, ))
    assert_size_stride(arg10_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg11_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg12_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg13_1, (2048, ), (1, ))
    assert_size_stride(arg14_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg15_1, (256, 2048), (2048, 1))
    assert_size_stride(arg16_1, (256, 2048), (2048, 1))
    assert_size_stride(arg17_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg18_1, (2048, ), (1, ))
    assert_size_stride(arg19_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg20_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg21_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg22_1, (2048, ), (1, ))
    assert_size_stride(arg23_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg24_1, (256, 2048), (2048, 1))
    assert_size_stride(arg25_1, (256, 2048), (2048, 1))
    assert_size_stride(arg26_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg27_1, (2048, ), (1, ))
    assert_size_stride(arg28_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg29_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg30_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg31_1, (2048, ), (1, ))
    assert_size_stride(arg32_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg33_1, (256, 2048), (2048, 1))
    assert_size_stride(arg34_1, (256, 2048), (2048, 1))
    assert_size_stride(arg35_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg36_1, (2048, ), (1, ))
    assert_size_stride(arg37_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg38_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg39_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg40_1, (2048, ), (1, ))
    assert_size_stride(arg41_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg42_1, (256, 2048), (2048, 1))
    assert_size_stride(arg43_1, (256, 2048), (2048, 1))
    assert_size_stride(arg44_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg45_1, (2048, ), (1, ))
    assert_size_stride(arg46_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg47_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg48_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg49_1, (2048, ), (1, ))
    assert_size_stride(arg50_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg51_1, (256, 2048), (2048, 1))
    assert_size_stride(arg52_1, (256, 2048), (2048, 1))
    assert_size_stride(arg53_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg54_1, (2048, ), (1, ))
    assert_size_stride(arg55_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg56_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg57_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg58_1, (2048, ), (1, ))
    assert_size_stride(arg59_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg60_1, (256, 2048), (2048, 1))
    assert_size_stride(arg61_1, (256, 2048), (2048, 1))
    assert_size_stride(arg62_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg63_1, (2048, ), (1, ))
    assert_size_stride(arg64_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg65_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg66_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg67_1, (2048, ), (1, ))
    assert_size_stride(arg68_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg69_1, (256, 2048), (2048, 1))
    assert_size_stride(arg70_1, (256, 2048), (2048, 1))
    assert_size_stride(arg71_1, (2048, 2048), (2048, 1))
    assert_size_stride(arg72_1, (2048, ), (1, ))
    assert_size_stride(arg73_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg74_1, (6144, 2048), (2048, 1))
    assert_size_stride(arg75_1, (2048, 6144), (6144, 1))
    assert_size_stride(arg76_1, (2048, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf1 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to, pow_1, variance, add, rsqrt, mul, hidden, hidden_states], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0.run(arg3_1, arg2_1, buf1, 1, 2048, stream=stream0)
        del arg2_1
        buf2 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf1, reinterpret_tensor(arg4_1, (2048, 2048), (1, 2048), 0), out=buf2)
        del arg4_1
        buf3 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf1, reinterpret_tensor(arg5_1, (2048, 256), (1, 2048), 0), out=buf3)
        del arg5_1
        buf4 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_cache, view_1, key_states_1, setitem], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_1.run(arg1_1, buf4, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [key_cache, view_1, key_states_1, setitem], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf3, buf4, 256, stream=stream0)
        buf6 = buf3; del buf3  # reuse
        # Topologically Sorted Source Nodes: [value_states], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf1, reinterpret_tensor(arg6_1, (2048, 256), (1, 2048), 0), out=buf6)
        del arg6_1
        buf7 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_1, view_2, value_states_1], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_3.run(buf4, arg1_1, buf7, s27, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_1, view_2, value_states_1], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf6, buf7, 256, stream=stream0)
        ps0 = 2097152*s27
        buf9 = empty_strided_cuda((2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        triton_poi_fused_4_xnumel = 4194304*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_4.run(buf7, buf4, arg1_1, buf9, ps0, s27, triton_poi_fused_4_xnumel, stream=stream0)
        del buf4
        buf10 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view, query_states_1, attn_output], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_mul_transpose_view_5.run(buf2, buf10, 2048, stream=stream0)
        buf11 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_6.run(buf9, buf11, 2048, 8192, stream=stream0)
        buf12 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view, query_states_1, attn_output], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf10, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf11, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf12)
        buf16 = reinterpret_tensor(buf12, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf12  # reuse
        # Topologically Sorted Source Nodes: [attn_output, arange, attn_mask], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7.run(buf16, arg7_1, 16, 8192, stream=stream0)
        buf17 = reinterpret_tensor(buf11, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf11  # reuse
        # Topologically Sorted Source Nodes: [setitem_1, attn_output], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_8.run(buf9, buf17, ps0, 16777216, stream=stream0)
        buf18 = reinterpret_tensor(buf10, (16, 1, 128), (128, 128, 1), 0); del buf10  # reuse
        # Topologically Sorted Source Nodes: [attn_output, arange, attn_mask, setitem_1], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf16, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf17, (16, 8192, 128), (1048576, 128, 1), 0), out=buf18)
        del buf16
        del buf17
        buf19 = reinterpret_tensor(buf2, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf2  # reuse
        # Topologically Sorted Source Nodes: [attn_output], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_9.run(buf18, buf19, 2048, stream=stream0)
        del buf18
        buf20 = buf1; del buf1  # reuse
        # Topologically Sorted Source Nodes: [attn_output, transpose_3, attn_output_2, attn_output_3], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf19, (1, 2048), (0, 1), 0), reinterpret_tensor(arg8_1, (2048, 2048), (1, 2048), 0), out=buf20)
        del arg8_1
        buf22 = reinterpret_tensor(buf19, (1, 2048), (2048, 1), 0); del buf19  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_1, to_2, pow_2, variance_1, add_2, rsqrt_1, mul_2, hidden_1, hidden_states_2], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(arg3_1, buf20, arg9_1, buf22, 1, 2048, stream=stream0)
        del arg9_1
        buf23 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_4], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf22, reinterpret_tensor(arg10_1, (2048, 6144), (1, 2048), 0), out=buf23)
        del arg10_1
        buf24 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_5], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf22, reinterpret_tensor(arg11_1, (2048, 6144), (1, 2048), 0), out=buf24)
        del arg11_1
        buf25 = buf23; del buf23  # reuse
        # Topologically Sorted Source Nodes: [silu, mul_4], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_11.run(buf25, buf24, 6144, stream=stream0)
        del buf24
        buf26 = buf22; del buf22  # reuse
        # Topologically Sorted Source Nodes: [silu, mul_4, hidden_states_3], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf25, reinterpret_tensor(arg12_1, (6144, 2048), (1, 6144), 0), out=buf26)
        del arg12_1
        del buf25
        buf28 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, to_4, pow_3, variance_2, add_4, rsqrt_2, mul_5, hidden_2, hidden_states_5], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(arg3_1, buf20, buf26, arg13_1, buf28, 1, 2048, stream=stream0)
        del arg13_1
        buf29 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_3], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf28, reinterpret_tensor(arg14_1, (2048, 2048), (1, 2048), 0), out=buf29)
        del arg14_1
        buf30 = buf6; del buf6  # reuse
        # Topologically Sorted Source Nodes: [key_states_2], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf28, reinterpret_tensor(arg15_1, (2048, 256), (1, 2048), 0), out=buf30)
        del arg15_1
        buf31 = buf7; del buf7  # reuse
        # Topologically Sorted Source Nodes: [setitem_2, view_4, key_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_13.run(buf9, buf31, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_2, view_4, key_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf30, buf31, 256, stream=stream0)
        buf33 = buf30; del buf30  # reuse
        # Topologically Sorted Source Nodes: [value_states_2], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf28, reinterpret_tensor(arg16_1, (2048, 256), (1, 2048), 0), out=buf33)
        del arg16_1
        buf34 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_3, view_5, value_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_14.run(buf31, buf9, buf34, ps0, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_3, view_5, value_states_3], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf33, buf34, 256, stream=stream0)
        del buf33
        buf36 = empty_strided_cuda((2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        triton_poi_fused_15_xnumel = 4194304*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_15.run(buf34, buf31, buf9, buf36, ps0, s27, triton_poi_fused_15_xnumel, stream=stream0)
        del buf31
        del buf9
        buf37 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_3, query_states_4, attn_output_4], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_mul_transpose_view_5.run(buf29, buf37, 2048, stream=stream0)
        buf38 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_4], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_16.run(buf36, buf38, 2048, 8192, stream=stream0)
        buf39 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_3, query_states_4, attn_output_4], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf37, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf38, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf39)
        buf43 = reinterpret_tensor(buf39, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf39  # reuse
        # Topologically Sorted Source Nodes: [attn_output_4, arange_1, attn_mask_1], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7.run(buf43, arg7_1, 16, 8192, stream=stream0)
        buf44 = reinterpret_tensor(buf38, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf38  # reuse
        # Topologically Sorted Source Nodes: [setitem_3, attn_output_4], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_17.run(buf36, buf44, ps0, 16777216, stream=stream0)
        buf45 = reinterpret_tensor(buf37, (16, 1, 128), (128, 128, 1), 0); del buf37  # reuse
        # Topologically Sorted Source Nodes: [attn_output_4, arange_1, attn_mask_1, setitem_3], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf43, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf44, (16, 8192, 128), (1048576, 128, 1), 0), out=buf45)
        del buf43
        del buf44
        buf46 = reinterpret_tensor(buf29, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf29  # reuse
        # Topologically Sorted Source Nodes: [attn_output_4], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_9.run(buf45, buf46, 2048, stream=stream0)
        buf47 = buf28; del buf28  # reuse
        # Topologically Sorted Source Nodes: [attn_output_4, transpose_7, attn_output_6, attn_output_7], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf46, (1, 2048), (0, 1), 0), reinterpret_tensor(arg17_1, (2048, 2048), (1, 2048), 0), out=buf47)
        del arg17_1
        buf49 = reinterpret_tensor(buf46, (1, 2048), (2048, 1), 0); del buf46  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, hidden_states_6, to_6, pow_4, variance_3, add_6, rsqrt_3, mul_7, hidden_3, hidden_states_7], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18.run(arg3_1, buf20, buf26, buf47, arg18_1, buf49, 1, 2048, stream=stream0)
        del arg18_1
        buf50 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_11], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf49, reinterpret_tensor(arg19_1, (2048, 6144), (1, 2048), 0), out=buf50)
        del arg19_1
        buf51 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_12], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf49, reinterpret_tensor(arg20_1, (2048, 6144), (1, 2048), 0), out=buf51)
        del arg20_1
        buf52 = buf50; del buf50  # reuse
        # Topologically Sorted Source Nodes: [silu_1, mul_9], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_11.run(buf52, buf51, 6144, stream=stream0)
        del buf51
        buf53 = buf49; del buf49  # reuse
        # Topologically Sorted Source Nodes: [silu_1, mul_9, hidden_states_8], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf52, reinterpret_tensor(arg21_1, (6144, 2048), (1, 6144), 0), out=buf53)
        del arg21_1
        del buf52
        buf54 = buf20; del buf20  # reuse
        buf56 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_1, hidden_states_4, hidden_states_6, hidden_states_9, to_8, pow_5, variance_4, add_8, rsqrt_4, mul_10, hidden_4, hidden_states_10], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_19.run(buf54, arg3_1, buf26, buf47, buf53, arg22_1, buf56, 1, 2048, stream=stream0)
        del arg22_1
        del arg3_1
        del buf26
        del buf47
        buf57 = buf53; del buf53  # reuse
        # Topologically Sorted Source Nodes: [query_states_6], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf56, reinterpret_tensor(arg23_1, (2048, 2048), (1, 2048), 0), out=buf57)
        del arg23_1
        buf58 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_4], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf56, reinterpret_tensor(arg24_1, (2048, 256), (1, 2048), 0), out=buf58)
        del arg24_1
        buf59 = buf34; del buf34  # reuse
        # Topologically Sorted Source Nodes: [setitem_4, view_7, key_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_20.run(buf36, buf59, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_4, view_7, key_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf58, buf59, 256, stream=stream0)
        buf61 = buf58; del buf58  # reuse
        # Topologically Sorted Source Nodes: [value_states_4], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf56, reinterpret_tensor(arg25_1, (2048, 256), (1, 2048), 0), out=buf61)
        del arg25_1
        del buf56
        buf62 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_5, view_8, value_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_21.run(buf59, buf36, buf62, ps0, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_5, view_8, value_states_5], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf61, buf62, 256, stream=stream0)
        del buf61
        buf64 = empty_strided_cuda((2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        triton_poi_fused_22_xnumel = 4194304*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_22.run(buf62, buf59, buf36, buf64, ps0, s27, triton_poi_fused_22_xnumel, stream=stream0)
        del buf36
        del buf59
        buf65 = reinterpret_tensor(buf45, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf45  # reuse
        # Topologically Sorted Source Nodes: [view_6, query_states_7, attn_output_8], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_mul_transpose_view_5.run(buf57, buf65, 2048, stream=stream0)
        buf66 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_8], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_23.run(buf64, buf66, 2048, 8192, stream=stream0)
        buf67 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_6, query_states_7, attn_output_8], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf65, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf66, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf67)
        buf71 = reinterpret_tensor(buf67, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf67  # reuse
        # Topologically Sorted Source Nodes: [attn_output_8, arange_2, attn_mask_2], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7.run(buf71, arg7_1, 16, 8192, stream=stream0)
        buf72 = reinterpret_tensor(buf66, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf66  # reuse
        # Topologically Sorted Source Nodes: [setitem_5, attn_output_8], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_24.run(buf64, buf72, ps0, 16777216, stream=stream0)
        buf73 = reinterpret_tensor(buf65, (16, 1, 128), (128, 128, 1), 0); del buf65  # reuse
        # Topologically Sorted Source Nodes: [attn_output_8, arange_2, attn_mask_2, setitem_5], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf71, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf72, (16, 8192, 128), (1048576, 128, 1), 0), out=buf73)
        del buf71
        del buf72
        buf74 = reinterpret_tensor(buf57, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf57  # reuse
        # Topologically Sorted Source Nodes: [attn_output_8], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_9.run(buf73, buf74, 2048, stream=stream0)
        del buf73
        buf75 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_8, transpose_11, attn_output_10, attn_output_11], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf74, (1, 2048), (0, 1), 0), reinterpret_tensor(arg26_1, (2048, 2048), (1, 2048), 0), out=buf75)
        del arg26_1
        buf77 = reinterpret_tensor(buf74, (1, 2048), (2048, 1), 0); del buf74  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_11, to_10, pow_6, variance_5, add_10, rsqrt_5, mul_12, hidden_5, hidden_states_12], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf54, buf75, arg27_1, buf77, 1, 2048, stream=stream0)
        del arg27_1
        buf78 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_18], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf77, reinterpret_tensor(arg28_1, (2048, 6144), (1, 2048), 0), out=buf78)
        del arg28_1
        buf79 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_19], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf77, reinterpret_tensor(arg29_1, (2048, 6144), (1, 2048), 0), out=buf79)
        del arg29_1
        buf80 = buf78; del buf78  # reuse
        # Topologically Sorted Source Nodes: [silu_2, mul_14], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_11.run(buf80, buf79, 6144, stream=stream0)
        del buf79
        buf81 = buf77; del buf77  # reuse
        # Topologically Sorted Source Nodes: [silu_2, mul_14, hidden_states_13], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf80, reinterpret_tensor(arg30_1, (6144, 2048), (1, 6144), 0), out=buf81)
        del arg30_1
        del buf80
        buf83 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_11, hidden_states_14, to_12, pow_7, variance_6, add_12, rsqrt_6, mul_15, hidden_6, hidden_states_15], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf54, buf75, buf81, arg31_1, buf83, 1, 2048, stream=stream0)
        del arg31_1
        buf84 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_9], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf83, reinterpret_tensor(arg32_1, (2048, 2048), (1, 2048), 0), out=buf84)
        del arg32_1
        buf85 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_6], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf83, reinterpret_tensor(arg33_1, (2048, 256), (1, 2048), 0), out=buf85)
        del arg33_1
        buf86 = buf62; del buf62  # reuse
        # Topologically Sorted Source Nodes: [setitem_6, view_10, key_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_25.run(buf64, buf86, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_6, view_10, key_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf85, buf86, 256, stream=stream0)
        buf88 = buf85; del buf85  # reuse
        # Topologically Sorted Source Nodes: [value_states_6], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf83, reinterpret_tensor(arg34_1, (2048, 256), (1, 2048), 0), out=buf88)
        del arg34_1
        del buf83
        buf89 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_7, view_11, value_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_26.run(buf86, buf64, buf89, ps0, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_7, view_11, value_states_7], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf88, buf89, 256, stream=stream0)
        del buf88
        buf91 = empty_strided_cuda((2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        triton_poi_fused_27_xnumel = 4194304*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_27.run(buf89, buf86, buf64, buf91, ps0, s27, triton_poi_fused_27_xnumel, stream=stream0)
        del buf64
        del buf86
        buf92 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_9, query_states_10, attn_output_12], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_mul_transpose_view_5.run(buf84, buf92, 2048, stream=stream0)
        buf93 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_12], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_28.run(buf91, buf93, 2048, 8192, stream=stream0)
        buf94 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_9, query_states_10, attn_output_12], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf92, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf93, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf94)
        buf98 = reinterpret_tensor(buf94, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf94  # reuse
        # Topologically Sorted Source Nodes: [attn_output_12, arange_3, attn_mask_3], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7.run(buf98, arg7_1, 16, 8192, stream=stream0)
        buf99 = reinterpret_tensor(buf93, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf93  # reuse
        # Topologically Sorted Source Nodes: [setitem_7, attn_output_12], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_29.run(buf91, buf99, ps0, 16777216, stream=stream0)
        buf100 = reinterpret_tensor(buf92, (16, 1, 128), (128, 128, 1), 0); del buf92  # reuse
        # Topologically Sorted Source Nodes: [attn_output_12, arange_3, attn_mask_3, setitem_7], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf98, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf99, (16, 8192, 128), (1048576, 128, 1), 0), out=buf100)
        del buf98
        del buf99
        buf101 = reinterpret_tensor(buf84, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf84  # reuse
        # Topologically Sorted Source Nodes: [attn_output_12], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_9.run(buf100, buf101, 2048, stream=stream0)
        buf102 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_12, transpose_15, attn_output_14, attn_output_15], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf101, (1, 2048), (0, 1), 0), reinterpret_tensor(arg35_1, (2048, 2048), (1, 2048), 0), out=buf102)
        del arg35_1
        buf104 = reinterpret_tensor(buf101, (1, 2048), (2048, 1), 0); del buf101  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_11, hidden_states_14, hidden_states_16, to_14, pow_8, variance_7, add_14, rsqrt_7, mul_17, hidden_7, hidden_states_17], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18.run(buf54, buf75, buf81, buf102, arg36_1, buf104, 1, 2048, stream=stream0)
        del arg36_1
        buf105 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_25], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf104, reinterpret_tensor(arg37_1, (2048, 6144), (1, 2048), 0), out=buf105)
        del arg37_1
        buf106 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_26], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf104, reinterpret_tensor(arg38_1, (2048, 6144), (1, 2048), 0), out=buf106)
        del arg38_1
        buf107 = buf105; del buf105  # reuse
        # Topologically Sorted Source Nodes: [silu_3, mul_19], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_11.run(buf107, buf106, 6144, stream=stream0)
        del buf106
        buf108 = buf104; del buf104  # reuse
        # Topologically Sorted Source Nodes: [silu_3, mul_19, hidden_states_18], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf107, reinterpret_tensor(arg39_1, (6144, 2048), (1, 6144), 0), out=buf108)
        del arg39_1
        del buf107
        buf109 = buf54; del buf54  # reuse
        buf111 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_11, hidden_states_14, hidden_states_16, hidden_states_19, to_16, pow_9, variance_8, add_16, rsqrt_8, mul_20, hidden_8, hidden_states_20], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_30.run(buf109, buf75, buf81, buf102, buf108, arg40_1, buf111, 1, 2048, stream=stream0)
        del arg40_1
        del buf102
        del buf108
        del buf75
        buf112 = buf81; del buf81  # reuse
        # Topologically Sorted Source Nodes: [query_states_12], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf111, reinterpret_tensor(arg41_1, (2048, 2048), (1, 2048), 0), out=buf112)
        del arg41_1
        buf113 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_8], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf111, reinterpret_tensor(arg42_1, (2048, 256), (1, 2048), 0), out=buf113)
        del arg42_1
        buf114 = buf89; del buf89  # reuse
        # Topologically Sorted Source Nodes: [setitem_8, view_13, key_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_31.run(buf91, buf114, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_8, view_13, key_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf113, buf114, 256, stream=stream0)
        buf116 = buf113; del buf113  # reuse
        # Topologically Sorted Source Nodes: [value_states_8], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf111, reinterpret_tensor(arg43_1, (2048, 256), (1, 2048), 0), out=buf116)
        del arg43_1
        del buf111
        buf117 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_9, view_14, value_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_32.run(buf114, buf91, buf117, ps0, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_9, view_14, value_states_9], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf116, buf117, 256, stream=stream0)
        del buf116
        buf119 = empty_strided_cuda((2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        triton_poi_fused_33_xnumel = 4194304*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_33.run(buf117, buf114, buf91, buf119, ps0, s27, triton_poi_fused_33_xnumel, stream=stream0)
        del buf114
        del buf91
        buf120 = reinterpret_tensor(buf100, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf100  # reuse
        # Topologically Sorted Source Nodes: [view_12, query_states_13, attn_output_16], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_mul_transpose_view_5.run(buf112, buf120, 2048, stream=stream0)
        buf121 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_16], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_34.run(buf119, buf121, 2048, 8192, stream=stream0)
        buf122 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_12, query_states_13, attn_output_16], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf120, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf121, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf122)
        buf126 = reinterpret_tensor(buf122, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf122  # reuse
        # Topologically Sorted Source Nodes: [attn_output_16, arange_4, attn_mask_4], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7.run(buf126, arg7_1, 16, 8192, stream=stream0)
        buf127 = reinterpret_tensor(buf121, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf121  # reuse
        # Topologically Sorted Source Nodes: [setitem_9, attn_output_16], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_35.run(buf119, buf127, ps0, 16777216, stream=stream0)
        buf128 = reinterpret_tensor(buf120, (16, 1, 128), (128, 128, 1), 0); del buf120  # reuse
        # Topologically Sorted Source Nodes: [attn_output_16, arange_4, attn_mask_4, setitem_9], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf126, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf127, (16, 8192, 128), (1048576, 128, 1), 0), out=buf128)
        del buf126
        del buf127
        buf129 = reinterpret_tensor(buf112, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf112  # reuse
        # Topologically Sorted Source Nodes: [attn_output_16], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_9.run(buf128, buf129, 2048, stream=stream0)
        del buf128
        buf130 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_16, transpose_19, attn_output_18, attn_output_19], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf129, (1, 2048), (0, 1), 0), reinterpret_tensor(arg44_1, (2048, 2048), (1, 2048), 0), out=buf130)
        del arg44_1
        buf132 = reinterpret_tensor(buf129, (1, 2048), (2048, 1), 0); del buf129  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_21, to_18, pow_10, variance_9, add_18, rsqrt_9, mul_22, hidden_9, hidden_states_22], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf109, buf130, arg45_1, buf132, 1, 2048, stream=stream0)
        del arg45_1
        buf133 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_32], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf132, reinterpret_tensor(arg46_1, (2048, 6144), (1, 2048), 0), out=buf133)
        del arg46_1
        buf134 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_33], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf132, reinterpret_tensor(arg47_1, (2048, 6144), (1, 2048), 0), out=buf134)
        del arg47_1
        buf135 = buf133; del buf133  # reuse
        # Topologically Sorted Source Nodes: [silu_4, mul_24], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_11.run(buf135, buf134, 6144, stream=stream0)
        del buf134
        buf136 = buf132; del buf132  # reuse
        # Topologically Sorted Source Nodes: [silu_4, mul_24, hidden_states_23], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf135, reinterpret_tensor(arg48_1, (6144, 2048), (1, 6144), 0), out=buf136)
        del arg48_1
        del buf135
        buf138 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_21, hidden_states_24, to_20, pow_11, variance_10, add_20, rsqrt_10, mul_25, hidden_10, hidden_states_25], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf109, buf130, buf136, arg49_1, buf138, 1, 2048, stream=stream0)
        del arg49_1
        buf139 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_15], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf138, reinterpret_tensor(arg50_1, (2048, 2048), (1, 2048), 0), out=buf139)
        del arg50_1
        buf140 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_10], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf138, reinterpret_tensor(arg51_1, (2048, 256), (1, 2048), 0), out=buf140)
        del arg51_1
        buf141 = buf117; del buf117  # reuse
        # Topologically Sorted Source Nodes: [setitem_10, view_16, key_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_36.run(buf119, buf141, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_10, view_16, key_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf140, buf141, 256, stream=stream0)
        buf143 = buf140; del buf140  # reuse
        # Topologically Sorted Source Nodes: [value_states_10], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf138, reinterpret_tensor(arg52_1, (2048, 256), (1, 2048), 0), out=buf143)
        del arg52_1
        del buf138
        buf144 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_11, view_17, value_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_37.run(buf141, buf119, buf144, ps0, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_11, view_17, value_states_11], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf143, buf144, 256, stream=stream0)
        del buf143
        buf146 = empty_strided_cuda((2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        triton_poi_fused_38_xnumel = 4194304*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_38.run(buf144, buf141, buf119, buf146, ps0, s27, triton_poi_fused_38_xnumel, stream=stream0)
        del buf119
        del buf141
        buf147 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_15, query_states_16, attn_output_20], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_mul_transpose_view_5.run(buf139, buf147, 2048, stream=stream0)
        buf148 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_20], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_39.run(buf146, buf148, 2048, 8192, stream=stream0)
        buf149 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_15, query_states_16, attn_output_20], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf147, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf148, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf149)
        buf153 = reinterpret_tensor(buf149, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf149  # reuse
        # Topologically Sorted Source Nodes: [attn_output_20, arange_5, attn_mask_5], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7.run(buf153, arg7_1, 16, 8192, stream=stream0)
        buf154 = reinterpret_tensor(buf148, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf148  # reuse
        # Topologically Sorted Source Nodes: [setitem_11, attn_output_20], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_40.run(buf146, buf154, ps0, 16777216, stream=stream0)
        buf155 = reinterpret_tensor(buf147, (16, 1, 128), (128, 128, 1), 0); del buf147  # reuse
        # Topologically Sorted Source Nodes: [attn_output_20, arange_5, attn_mask_5, setitem_11], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf153, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf154, (16, 8192, 128), (1048576, 128, 1), 0), out=buf155)
        del buf153
        del buf154
        buf156 = reinterpret_tensor(buf139, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf139  # reuse
        # Topologically Sorted Source Nodes: [attn_output_20], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_9.run(buf155, buf156, 2048, stream=stream0)
        buf157 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_20, transpose_23, attn_output_22, attn_output_23], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf156, (1, 2048), (0, 1), 0), reinterpret_tensor(arg53_1, (2048, 2048), (1, 2048), 0), out=buf157)
        del arg53_1
        buf159 = reinterpret_tensor(buf156, (1, 2048), (2048, 1), 0); del buf156  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_21, hidden_states_24, hidden_states_26, to_22, pow_12, variance_11, add_22, rsqrt_11, mul_27, hidden_11, hidden_states_27], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18.run(buf109, buf130, buf136, buf157, arg54_1, buf159, 1, 2048, stream=stream0)
        del arg54_1
        buf160 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_39], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf159, reinterpret_tensor(arg55_1, (2048, 6144), (1, 2048), 0), out=buf160)
        del arg55_1
        buf161 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_40], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf159, reinterpret_tensor(arg56_1, (2048, 6144), (1, 2048), 0), out=buf161)
        del arg56_1
        buf162 = buf160; del buf160  # reuse
        # Topologically Sorted Source Nodes: [silu_5, mul_29], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_11.run(buf162, buf161, 6144, stream=stream0)
        del buf161
        buf163 = buf159; del buf159  # reuse
        # Topologically Sorted Source Nodes: [silu_5, mul_29, hidden_states_28], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf162, reinterpret_tensor(arg57_1, (6144, 2048), (1, 6144), 0), out=buf163)
        del arg57_1
        del buf162
        buf164 = buf109; del buf109  # reuse
        buf166 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_21, hidden_states_24, hidden_states_26, hidden_states_29, to_24, pow_13, variance_12, add_24, rsqrt_12, mul_30, hidden_12, hidden_states_30], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_30.run(buf164, buf130, buf136, buf157, buf163, arg58_1, buf166, 1, 2048, stream=stream0)
        del arg58_1
        del buf130
        del buf136
        del buf157
        buf167 = buf163; del buf163  # reuse
        # Topologically Sorted Source Nodes: [query_states_18], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf166, reinterpret_tensor(arg59_1, (2048, 2048), (1, 2048), 0), out=buf167)
        del arg59_1
        buf168 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_12], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf166, reinterpret_tensor(arg60_1, (2048, 256), (1, 2048), 0), out=buf168)
        del arg60_1
        buf169 = buf144; del buf144  # reuse
        # Topologically Sorted Source Nodes: [setitem_12, view_19, key_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_41.run(buf146, buf169, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_12, view_19, key_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf168, buf169, 256, stream=stream0)
        buf171 = buf168; del buf168  # reuse
        # Topologically Sorted Source Nodes: [value_states_12], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf166, reinterpret_tensor(arg61_1, (2048, 256), (1, 2048), 0), out=buf171)
        del arg61_1
        del buf166
        buf172 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_13, view_20, value_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_42.run(buf169, buf146, buf172, ps0, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_13, view_20, value_states_13], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf171, buf172, 256, stream=stream0)
        del buf171
        buf174 = empty_strided_cuda((2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: []
        triton_poi_fused_43_xnumel = 4194304*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_43.run(buf172, buf169, buf146, buf174, ps0, s27, triton_poi_fused_43_xnumel, stream=stream0)
        del buf146
        del buf169
        buf175 = reinterpret_tensor(buf155, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf155  # reuse
        # Topologically Sorted Source Nodes: [view_18, query_states_19, attn_output_24], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_mul_transpose_view_5.run(buf167, buf175, 2048, stream=stream0)
        buf176 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_24], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_44.run(buf174, buf176, 2048, 8192, stream=stream0)
        buf177 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_18, query_states_19, attn_output_24], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf175, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf176, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf177)
        buf181 = reinterpret_tensor(buf177, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf177  # reuse
        # Topologically Sorted Source Nodes: [attn_output_24, arange_6, attn_mask_6], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7.run(buf181, arg7_1, 16, 8192, stream=stream0)
        buf182 = reinterpret_tensor(buf176, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf176  # reuse
        # Topologically Sorted Source Nodes: [setitem_13, attn_output_24], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_45.run(buf174, buf182, ps0, 16777216, stream=stream0)
        buf183 = reinterpret_tensor(buf175, (16, 1, 128), (128, 128, 1), 0); del buf175  # reuse
        # Topologically Sorted Source Nodes: [attn_output_24, arange_6, attn_mask_6, setitem_13], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf181, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf182, (16, 8192, 128), (1048576, 128, 1), 0), out=buf183)
        del buf181
        del buf182
        buf184 = reinterpret_tensor(buf167, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf167  # reuse
        # Topologically Sorted Source Nodes: [attn_output_24], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_9.run(buf183, buf184, 2048, stream=stream0)
        del buf183
        buf185 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_24, transpose_27, attn_output_26, attn_output_27], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf184, (1, 2048), (0, 1), 0), reinterpret_tensor(arg62_1, (2048, 2048), (1, 2048), 0), out=buf185)
        del arg62_1
        buf187 = reinterpret_tensor(buf184, (1, 2048), (2048, 1), 0); del buf184  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_31, to_26, pow_14, variance_13, add_26, rsqrt_13, mul_32, hidden_13, hidden_states_32], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_10.run(buf164, buf185, arg63_1, buf187, 1, 2048, stream=stream0)
        del arg63_1
        buf188 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_46], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf187, reinterpret_tensor(arg64_1, (2048, 6144), (1, 2048), 0), out=buf188)
        del arg64_1
        buf189 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_47], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf187, reinterpret_tensor(arg65_1, (2048, 6144), (1, 2048), 0), out=buf189)
        del arg65_1
        buf190 = buf188; del buf188  # reuse
        # Topologically Sorted Source Nodes: [silu_6, mul_34], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_11.run(buf190, buf189, 6144, stream=stream0)
        del buf189
        buf191 = buf187; del buf187  # reuse
        # Topologically Sorted Source Nodes: [silu_6, mul_34, hidden_states_33], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf190, reinterpret_tensor(arg66_1, (6144, 2048), (1, 6144), 0), out=buf191)
        del arg66_1
        del buf190
        buf193 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_states_31, hidden_states_34, to_28, pow_15, variance_14, add_28, rsqrt_14, mul_35, hidden_14, hidden_states_35], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_12.run(buf164, buf185, buf191, arg67_1, buf193, 1, 2048, stream=stream0)
        del arg67_1
        buf194 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_21], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf193, reinterpret_tensor(arg68_1, (2048, 2048), (1, 2048), 0), out=buf194)
        del arg68_1
        buf195 = empty_strided_cuda((1, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_14], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf193, reinterpret_tensor(arg69_1, (2048, 256), (1, 2048), 0), out=buf195)
        del arg69_1
        buf196 = buf172; del buf172  # reuse
        # Topologically Sorted Source Nodes: [setitem_14, view_22, key_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_46.run(buf174, buf196, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_14, view_22, key_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf195, buf196, 256, stream=stream0)
        buf198 = buf195; del buf195  # reuse
        # Topologically Sorted Source Nodes: [value_states_14], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf193, reinterpret_tensor(arg70_1, (2048, 256), (1, 2048), 0), out=buf198)
        del arg70_1
        del buf193
        buf199 = empty_strided_cuda((1, 2, 8192, 128), (2097152, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [setitem_15, view_23, value_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_47.run(buf196, buf174, buf199, ps0, 2097152, stream=stream0)
        # Topologically Sorted Source Nodes: [setitem_15, view_23, value_states_15], Original ATen: [aten.select, aten.view, aten.transpose, aten.index_put]
        stream0 = get_raw_stream(0)
        triton_poi_fused_index_put_select_transpose_view_2.run(arg7_1, buf198, buf199, 256, stream=stream0)
        del buf198
        buf201 = empty_strided_cuda((2, s27, 1, 2, 8192, 128), (2097152*s27, 2097152, 4194304*s27, 1048576, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [], Original ATen: [aten.copy_]
        triton_poi_fused_copy__48_xnumel = 4194304*s27
        stream0 = get_raw_stream(0)
        triton_poi_fused_copy__48.run(buf199, buf196, buf174, buf201, arg1_1, ps0, s27, triton_poi_fused_copy__48_xnumel, stream=stream0)
        del arg1_1
        del buf174
        del buf196
        del buf199
        buf202 = empty_strided_cuda((1, 16, 1, 128), (2048, 128, 128, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_21, query_states_22, attn_output_28], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_mul_transpose_view_5.run(buf194, buf202, 2048, stream=stream0)
        buf203 = empty_strided_cuda((1, 16, 128, 8192), (16777216, 1048576, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_28], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.view, aten.transpose, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_mul_select_transpose_unsqueeze_view_49.run(buf201, buf203, 2048, 8192, stream=stream0)
        buf204 = empty_strided_cuda((16, 1, 8192), (8192, 8192, 1), torch.float32)
        # Topologically Sorted Source Nodes: [view_21, query_states_22, attn_output_28], Original ATen: [aten.view, aten.transpose, aten._to_copy, aten.mul, aten.select, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf202, (16, 1, 128), (128, 0, 1), 0), reinterpret_tensor(buf203, (16, 128, 8192), (1048576, 8192, 1), 0), out=buf204)
        buf208 = reinterpret_tensor(buf204, (1, 16, 1, 8192), (131072, 8192, 8192, 1), 0); del buf204  # reuse
        # Topologically Sorted Source Nodes: [attn_output_28, arange_7, attn_mask_7], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, prims.prepare_softmax_online, aten.sub, aten.exp]
        stream0 = get_raw_stream(0)
        triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7.run(buf208, arg7_1, 16, 8192, stream=stream0)
        del arg7_1
        buf209 = reinterpret_tensor(buf203, (1, 2, 8, 8192, 128), (16777216, 8388608, 1048576, 128, 1), 0); del buf203  # reuse
        # Topologically Sorted Source Nodes: [setitem_15, attn_output_28], Original ATen: [aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clone_expand_select_unsqueeze_50.run(buf201, buf209, ps0, 16777216, stream=stream0)
        del buf201
        buf210 = reinterpret_tensor(buf202, (16, 1, 128), (128, 128, 1), 0); del buf202  # reuse
        # Topologically Sorted Source Nodes: [attn_output_28, arange_7, attn_mask_7, setitem_15], Original ATen: [aten.view, aten.arange, aten.le, aten.scalar_tensor, aten.where, aten.add, aten._safe_softmax, aten.sub, aten.exp, aten.select, aten._to_copy, aten.unsqueeze, aten.expand, aten.clone, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf208, (16, 1, 8192), (8192, 0, 1), 0), reinterpret_tensor(buf209, (16, 8192, 128), (1048576, 128, 1), 0), out=buf210)
        del buf208
        del buf209
        buf211 = reinterpret_tensor(buf194, (1, 16, 1, 128), (2048, 128, 128, 1), 0); del buf194  # reuse
        # Topologically Sorted Source Nodes: [attn_output_28], Original ATen: [aten.view, aten._to_copy]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_view_9.run(buf210, buf211, 2048, stream=stream0)
        del buf210
        buf212 = empty_strided_cuda((1, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_28, transpose_31, attn_output_30, attn_output_31], Original ATen: [aten.view, aten._to_copy, aten.transpose, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf211, (1, 2048), (0, 1), 0), reinterpret_tensor(arg71_1, (2048, 2048), (1, 2048), 0), out=buf212)
        del arg71_1
        buf214 = reinterpret_tensor(buf211, (1, 2048), (2048, 1), 0); del buf211  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_31, hidden_states_34, hidden_states_36, to_30, pow_16, variance_15, add_30, rsqrt_15, mul_37, hidden_15, hidden_states_37], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_18.run(buf164, buf185, buf191, buf212, arg72_1, buf214, 1, 2048, stream=stream0)
        del arg72_1
        buf215 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_53], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf214, reinterpret_tensor(arg73_1, (2048, 6144), (1, 2048), 0), out=buf215)
        del arg73_1
        buf216 = empty_strided_cuda((1, 6144), (6144, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_54], Original ATen: [aten.t, aten.mm]
        extern_kernels.mm(buf214, reinterpret_tensor(arg74_1, (2048, 6144), (1, 2048), 0), out=buf216)
        del arg74_1
        buf217 = buf215; del buf215  # reuse
        # Topologically Sorted Source Nodes: [silu_7, mul_39], Original ATen: [aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_11.run(buf217, buf216, 6144, stream=stream0)
        del buf216
        buf218 = buf214; del buf214  # reuse
        # Topologically Sorted Source Nodes: [silu_7, mul_39, hidden_states_38], Original ATen: [aten.silu, aten.mul, aten.t, aten.mm]
        extern_kernels.mm(buf217, reinterpret_tensor(arg75_1, (6144, 2048), (1, 6144), 0), out=buf218)
        del arg75_1
        del buf217
        buf219 = buf164; del buf164  # reuse
        buf221 = buf219; del buf219  # reuse
        # Topologically Sorted Source Nodes: [hidden_states_31, hidden_states_34, hidden_states_36, hidden_states_39, to_32, pow_17, variance_16, add_32, rsqrt_16, mul_40, hidden_16, hidden_states_40], Original ATen: [aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_51.run(buf221, buf185, buf191, buf212, buf218, arg76_1, 1, 2048, stream=stream0)
        del arg76_1
        del buf185
        del buf191
        del buf212
        del buf218
    return (buf221, )


# Call wait() on the module-level AsyncCompile instance, handing it this
# module's globals dict.
# NOTE(review): presumably this blocks until the asynchronously compiled
# kernels are ready and rebinds their names in globals() -- confirm against
# AsyncCompile.wait.
async_compile.wait(globals())
# Drop the module-level name once the wait call has returned.
del async_compile

class Runner:
    """Own the compiled graph partitions and route call inputs to them.

    ``partitions`` is a sequence of callables. :meth:`call` only ever invokes
    ``partitions[0]``.
    """

    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace every partition with ``fn(partition)``, paired by position.

        Pairing goes through ``zip``, so the new partition list is as long as
        the shorter of ``fns`` and the current partition list.
        """
        self.partitions = [
            wrap(partition) for wrap, partition in zip(fns, self.partitions)
        ]

    def call(self, args):
        """Run partition 0 on the graph inputs and return its single output.

        Args:
            args: mutable sequence of exactly 77 inputs (arg0_1 .. arg76_1).
                It is emptied in place once its contents have been read.

        Returns:
            A 1-tuple holding the only value produced by ``partitions[0]``.

        Raises:
            ValueError: if ``args`` does not hold exactly 77 items (checked
                before ``args`` is cleared), or if the partition does not
                return exactly one value.
        """
        expected = 77
        inputs = list(args)
        got = len(inputs)
        if got < expected:
            raise ValueError(
                f"not enough values to unpack (expected {expected}, got {got})"
            )
        if got > expected:
            raise ValueError(f"too many values to unpack (expected {expected})")
        args.clear()

        # Input 0 doubles as the symbolic size s27 and is passed last.
        s27 = inputs[0]
        # Positions of the original inputs in the order partition 0 expects:
        # a short permuted prefix, then inputs 8..76 in sequence.
        order = (3, 2, 4, 5, 1, 7, 6, *range(8, expected))
        partition0_args = [inputs[i] for i in order]
        partition0_args.append(s27)
        # Drop the temporary list so partition0_args is the only container in
        # this frame that still references the inputs during the call.
        del inputs

        (buf221,) = self.partitions[0](partition0_args)
        return (buf221,)

# Single module-level Runner wrapping partition_0 as its only partition.
runner = Runner(partitions=[partition_0,])
# Expose the runner's bound methods under module-level names.
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def get_args():
    """Create example inputs for ``call`` in positional order ``arg0_1`` .. ``arg76_1``.

    Returns:
        A 77-element list. Element 0 is the Python int ``8``; every other
        element is a tensor produced by ``rand_strided`` on ``'cuda:0'``.
        All tensors are ``torch.bfloat16`` except element 7, which is a
        one-element ``torch.int64`` tensor.
    """
    from torch._dynamo.testing import rand_strided

    def dense(*shape, dtype=torch.bfloat16):
        # Derive row-major (contiguous) strides for ``shape`` so each call site
        # only has to spell out the size.
        strides = []
        step = 1
        for extent in reversed(shape):
            strides.append(step)
            step *= extent
        return rand_strided(shape, tuple(reversed(strides)), device='cuda:0', dtype=dtype)

    # arg0_1: plain integer input.
    inputs = [8]
    # arg1_1: the only input whose strides are not plain row-major, so they
    # are written out explicitly.
    inputs.append(
        rand_strided(
            (2, 8, 1, 2, 8192, 128),
            (16777216, 2097152, 2097152, 1048576, 128, 1),
            device='cuda:0',
            dtype=torch.bfloat16,
        )
    )
    inputs.append(dense(2048))      # arg2_1
    inputs.append(dense(1, 2048))   # arg3_1

    # arg4_1 .. arg76_1: the same sequence of shapes repeats eight times
    # (presumably one group per model layer -- not confirmed from this file).
    # Tensors are created strictly in argument order.
    for group in range(8):
        inputs.append(dense(2048, 2048))
        inputs.append(dense(256, 2048))
        inputs.append(dense(256, 2048))
        if group == 0:
            # arg7_1: only the first group carries this extra int64 element.
            inputs.append(dense(1, dtype=torch.int64))
        inputs.append(dense(2048, 2048))
        inputs.append(dense(2048))
        inputs.append(dense(6144, 2048))
        inputs.append(dense(6144, 2048))
        inputs.append(dense(2048, 6144))
        inputs.append(dense(2048))

    return inputs


def benchmark_compiled_module(args, times=10, repeat=10):
    """Benchmark the compiled ``call`` entry point on ``args``.

    Args:
        args: Sequence of inputs for ``call``; it is copied into a new list
            on every invocation.
        times: Forwarded to ``print_performance`` as ``times``.
        repeat: Forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns for the benchmarked callable.
    """
    from torch._inductor.utils import print_performance

    def run_once():
        # Hand ``call`` a fresh list each time so the caller's ``args`` stays
        # intact across repeated invocations (presumably because the wrapper
        # empties the list it receives -- see ``args.clear()`` above).
        return call(list(args))

    return print_performance(run_once, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main

    # Example inputs are built once and reused by every benchmark invocation.
    args = get_args()

    def _run_benchmark(times, repeat):
        # Adapter with the (times, repeat) signature handed to
        # compiled_module_main; binds the prebuilt inputs.
        return benchmark_compiled_module(args, times=times, repeat=repeat)

    compiled_module_main('None', _run_benchmark)
