# AOT ID: ['1_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Module-level shorthands, bound once at import time.

# Op namespaces under torch.ops.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Helpers exposed by torch._C._dynamo.guards (assertions, strided-empty
# constructors per device type, and tensor reinterpretation).
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Single AsyncCompile instance; its .triton(name, source, device_str=...) method
# is called below once per kernel source string.
async_compile = AsyncCompile()
# NOTE(review): this attribute lookup happens at import time, so the module
# presumably requires a torch build that exposes _distributed_c10d — confirm.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
from torch._C import _cuda_getCurrentRawStream as get_raw_stream



# kernel path: /tmp/torchinductor_shingokuga/mu/cmutgtnotiep3chterc2h642e7dfkii746wqglnat2si6v3n3rll.py
# Topologically Sorted Source Nodes: [unsqueeze, mul_1, arange, mul, emb, unsqueeze_1, emb_1, sin, cos, emb_2], Original ATen: [aten.unsqueeze, aten.mul, aten.arange, aten.exp, aten.sin, aten.cos, aten.cat]
# Source node to ATen node mapping:
#   arange => add, convert_element_type_6, iota, mul
#   cos => cos
#   emb => exp
#   emb_1 => mul_3
#   emb_2 => cat
#   mul => mul_1
#   mul_1 => mul_2
#   sin => sin
#   unsqueeze => unsqueeze
#   unsqueeze_1 => unsqueeze_1
# Graph fragment:
#   %arg6_1 : Tensor "bf16[2][1]cuda:0" = PlaceHolder[target=arg6_1]
#   %unsqueeze : Tensor "bf16[2, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg6_1, 1), kwargs = {})
#   %mul_2 : Tensor "bf16[2, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%unsqueeze, 1000), kwargs = {})
#   %iota : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.iota.default](args = (512,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %mul : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%iota, 1), kwargs = {})
#   %add : Tensor "i64[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, 0), kwargs = {})
#   %convert_element_type_6 : Tensor "bf16[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add, torch.bfloat16), kwargs = {})
#   %mul_1 : Tensor "bf16[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_6, -0.01802414945592208), kwargs = {})
#   %exp : Tensor "bf16[512][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%mul_1,), kwargs = {})
#   %unsqueeze_1 : Tensor "bf16[1, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%exp, 0), kwargs = {})
#   %mul_3 : Tensor "bf16[2, 512][512, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_2, %unsqueeze_1), kwargs = {})
#   %sin : Tensor "bf16[2, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mul_3,), kwargs = {})
#   %cos : Tensor "bf16[2, 512][512, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%mul_3,), kwargs = {})
#   %cat : Tensor "bf16[2, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%sin, %cos], -1), kwargs = {})
#   return %cat
# Pointwise kernel: builds a [2, 1024] sin/cos embedding from a 2-element input.
#
# The kernel source below is handed to async_compile.triton as a string, so it
# must not be edited casually (it is runtime text, not ordinary Python code).
#
# What the kernel body does:
#   * xnumel is hard-coded to 2048; the flat index is split into
#     row = xindex // 1024 and col = xindex % 1024.
#   * One input value is loaded per row (in_ptr0 + row) and multiplied by 1000.
#   * For col < 512 the output is sin(in * 1000 * exp(col * c)).
#   * For col >= 512 the output is cos(in * 1000 * exp((col - 512) * c)).
#   * c = -0.01802414945592208 (presumably -ln(10000)/511 — TODO confirm against
#     the model source).
#   * Intermediate math is done in float32; pointers are declared as bf16 in the
#     kernel signature.
triton_poi_fused_arange_cat_cos_exp_mul_sin_unsqueeze_0 = async_compile.triton('triton_poi_fused_arange_cat_cos_exp_mul_sin_unsqueeze_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2048}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_arange_cat_cos_exp_mul_sin_unsqueeze_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 8192}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_arange_cat_cos_exp_mul_sin_unsqueeze_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 1024)
    x1 = xindex // 1024
    x2 = xindex
    tmp0 = x0
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 512, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x1), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tl.full([1], 1000.0, tl.float32)
    tmp7 = tmp5 * tmp6
    tmp8 = x0
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tl.full([1], -0.01802414945592208, tl.float32)
    tmp11 = tmp9 * tmp10
    tmp12 = libdevice.exp(tmp11)
    tmp13 = tmp7 * tmp12
    tmp14 = tl_math.sin(tmp13)
    tmp15 = tl.full(tmp14.shape, 0.0, tmp14.dtype)
    tmp16 = tl.where(tmp4, tmp14, tmp15)
    tmp17 = tmp0 >= tmp3
    tmp18 = tl.full([1], 1024, tl.int64)
    tmp19 = tmp0 < tmp18
    tmp20 = tl.load(in_ptr0 + (x1), tmp17 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp21 = tl.full([1], 1000.0, tl.float32)
    tmp22 = tmp20 * tmp21
    tmp23 = (-512) + x0
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tl.full([1], -0.01802414945592208, tl.float32)
    tmp26 = tmp24 * tmp25
    tmp27 = libdevice.exp(tmp26)
    tmp28 = tmp22 * tmp27
    tmp29 = tl_math.cos(tmp28)
    tmp30 = tl.full(tmp29.shape, 0.0, tmp29.dtype)
    tmp31 = tl.where(tmp17, tmp29, tmp30)
    tmp32 = tl.where(tmp4, tmp16, tmp31)
    tl.store(out_ptr0 + (x2), tmp32, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/o5/co5p7tt2e6b6rh4jfxckn3hoc7ibl4jteliu66blvqd7ugyuoa7p.py
# Topologically Sorted Source Nodes: [sample_1], Original ATen: [aten.silu]
# Source node to ATen node mapping:
#   sample_1 => add_1, convert_element_type_10, convert_element_type_11, div, exp_1, neg
# Graph fragment:
#   %addmm_2 : Tensor "bf16[2, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm_2]
#   %convert_element_type_10 : Tensor "f32[2, 1024][1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%addmm_2, torch.float32), kwargs = {})
#   %neg : Tensor "f32[2, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%convert_element_type_10,), kwargs = {})
#   %exp_1 : Tensor "f32[2, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%neg,), kwargs = {})
#   %add_1 : Tensor "f32[2, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%exp_1, 1), kwargs = {})
#   %div : Tensor "f32[2, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_10, %add_1), kwargs = {})
#   %convert_element_type_11 : Tensor "bf16[2, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%div, torch.bfloat16), kwargs = {})
#   return %convert_element_type_11
# Pointwise kernel: in-place SiLU over a flat buffer of 2048 elements.
#
# The kernel source below is handed to async_compile.triton as a string, so it
# is runtime text and must be kept byte-exact.
#
# What the kernel body does:
#   * xnumel is hard-coded to 2048; each element is addressed by its flat index.
#   * Loads x from in_out_ptr0, converts to float32, and computes
#     x / (exp(-x) + 1).
#   * Stores the result back to the same pointer (in_out_ptr0 is listed in
#     'mutated_arg_names'), so the input buffer is overwritten.
triton_poi_fused_silu_1 = async_compile.triton('triton_poi_fused_silu_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 2048}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_silu_1', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 12288}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_silu_1(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2048
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = -tmp1
    tmp3 = libdevice.exp(tmp2)
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = (tmp1 / tmp5)
    tmp7 = tmp6.to(tl.float32)
    tl.store(in_out_ptr0 + (x0), tmp7, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ks/cksormikubetxxeoogormaw25dogkn67wzcida5rm3y5cxxumwap.py
# Topologically Sorted Source Nodes: [transpose_1, contiguous_1], Original ATen: [aten.transpose, aten.clone]
# Source node to ATen node mapping:
#   contiguous_1 => clone_1
#   transpose_1 => permute_2
# Graph fragment:
#   %arg3_1 : Tensor "bf16[2, 64, 4][256, 4, 1]cuda:0" = PlaceHolder[target=arg3_1]
#   %permute_2 : Tensor "bf16[2, 4, 64][256, 1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg3_1, [0, 2, 1]), kwargs = {})
#   %clone_1 : Tensor "bf16[2, 4, 64][256, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_2,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_1
# 2D pointwise kernel: copies a transposed view into a separate output buffer.
#
# The kernel source below is handed to async_compile.triton as a string, so it
# is runtime text and must be kept byte-exact.
#
# What the kernel body does:
#   * ynumel is hard-coded to 8 and xnumel to 64; y is split into
#     y0 = yindex % 4 and y1 = yindex // 4.
#   * Reads in_ptr0 at offset y0 + 4*x + 256*y1 and writes out_ptr0 at
#     offset x + 64*yindex, i.e. the last two logical axes are swapped on the
#     way from input to output.
#   * Per the graph fragment above, this corresponds to permuting a
#     [2, 64, 4] tensor to [2, 4, 64] and cloning it to contiguous layout.
triton_poi_fused_clone_transpose_2 = async_compile.triton('triton_poi_fused_clone_transpose_2', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 8, 'x': 64}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_transpose_2', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 1024, 'x': 2048}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_transpose_2(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 8
    xnumel = 64
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = xindex < xnumel
    x2 = xindex
    y0 = (yindex % 4)
    y1 = yindex // 4
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + 4*x2 + 256*y1), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
    tl.store(out_ptr0 + (x2 + 64*y3), tmp0, xmask & ymask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ex/cexd73fz6arqcee6i4ejjouz3aayyhszvx3rotcwkk33fk4zvbtm.py
# Topologically Sorted Source Nodes: [mu, t_1, unsqueeze_4, cond, x, x_1, to_2, pow_1, variance, add_1, rsqrt, mul_6, hidden, hidden_states], Original ATen: [aten.view, aten.add, aten.unsqueeze, aten.cat, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_1 => add_5
#   cond => view_3
#   hidden => convert_element_type_25
#   hidden_states => mul_9
#   mu => view_4
#   mul_6 => mul_8
#   pow_1 => pow_1
#   rsqrt => rsqrt
#   t_1 => add_4
#   to_2 => convert_element_type_24
#   unsqueeze_4 => unsqueeze_4
#   variance => mean
#   x => view_1
#   x_1 => cat_2
# Graph fragment:
#   %arg16_1 : Tensor "bf16[2, 2048][2048, 1]cuda:0" = PlaceHolder[target=arg16_1]
#   %addmm_3 : Tensor "bf16[2, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm_3]
#   %addmm_5 : Tensor "bf16[2, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm_5]
#   %addmm_1 : Tensor "bf16[8, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm_1]
#   %addmm : Tensor "bf16[8, 1024][1024, 1]cuda:0" = PlaceHolder[target=addmm]
#   %cat_2 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=cat_2]
#   %buf13 : Tensor "f32[2, 11, 1][11, 1, 22]cuda:0" = PlaceHolder[target=buf13]
#   %arg19_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg19_1]
#   %view_4 : Tensor "bf16[2, 2, 1024][2048, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arg16_1, [2, -1, 1024]), kwargs = {})
#   %add_4 : Tensor "bf16[2, 1024][1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm_3, %addmm_5), kwargs = {})
#   %unsqueeze_4 : Tensor "bf16[2, 1, 1024][1024, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%add_4, 1), kwargs = {})
#   %view_3 : Tensor "bf16[2, 4, 1024][4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm_1, [2, 4, 1024]), kwargs = {})
#   %view_1 : Tensor "bf16[2, 4, 1024][4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%addmm, [2, 4, 1024]), kwargs = {})
#   %cat_2 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.cat.default](args = ([%view_4, %unsqueeze_4, %view_3, %view_1], 1), kwargs = {})
#   %convert_element_type_24 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%cat_2, torch.float32), kwargs = {})
#   %pow_1 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_24, 2), kwargs = {})
#   %mean : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
#   %add_5 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
#   %rsqrt : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_5,), kwargs = {})
#   %mul_8 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %rsqrt), kwargs = {})
#   %convert_element_type_25 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_8, torch.bfloat16), kwargs = {})
#   %mul_9 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_25, %arg19_1), kwargs = {})
#   return %cat_2,%buf13,%mul_9
# Persistent-reduction kernel: fused concatenation + RMS-style normalization.
#
# The kernel source below is handed to async_compile.triton as a string, so it
# is runtime text and must be kept byte-exact.
#
# What the kernel body does (xnumel = 22 rows, r0_numel = 1024 columns):
#   * Each row index is split into x0 = xindex % 11 and x1 = xindex // 11.
#   * The row's source is chosen by x0:
#       x0 in [0, 2)   -> in_ptr0[r + 1024*x0 + 2048*x1]
#       x0 == 2        -> in_ptr1[r + 1024*x1] + in_ptr2[r + 1024*x1]
#       x0 in [3, 7)   -> in_ptr3[r + 1024*(x0 - 3) + 4096*x1]
#       x0 in [7, 11)  -> in_ptr4[r + 1024*(x0 - 7) + 4096*x1]
#   * The selected value is stored to out_ptr0 (the concatenated tensor).
#   * The squared values are summed across the 1024 columns, divided by
#     1024.0, offset by 1e-05, and passed through rsqrt.
#   * value * rsqrt(...) is multiplied by in_ptr5[r] (one weight per column)
#     and stored to out_ptr2.
triton_per_fused__to_copy_add_cat_mean_mul_pow_rsqrt_unsqueeze_view_3 = async_compile.triton('triton_per_fused__to_copy_add_cat_mean_mul_pow_rsqrt_unsqueeze_view_3', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'out_ptr0': '*bf16', 'out_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_cat_mean_mul_pow_rsqrt_unsqueeze_view_3', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 231424}}
)
@triton.jit
def triton_per_fused__to_copy_add_cat_mean_mul_pow_rsqrt_unsqueeze_view_3(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr0, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 22
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    x0 = (xindex % 11)
    r0_2 = r0_index
    x1 = xindex // 11
    x3 = xindex
    tmp40 = tl.load(in_ptr5 + (r0_2), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 2, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (r0_2 + 1024*(x0) + 2048*x1), tmp4 & xmask, other=0.0).to(tl.float32)
    tmp6 = tmp0 >= tmp3
    tmp7 = tl.full([1, 1], 3, tl.int64)
    tmp8 = tmp0 < tmp7
    tmp9 = tmp6 & tmp8
    tmp10 = tl.load(in_ptr1 + (r0_2 + 1024*x1), tmp9 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tl.load(in_ptr2 + (r0_2 + 1024*x1), tmp9 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp12 = tmp10 + tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp0 >= tmp7
    tmp16 = tl.full([1, 1], 7, tl.int64)
    tmp17 = tmp0 < tmp16
    tmp18 = tmp15 & tmp17
    tmp19 = tl.load(in_ptr3 + (r0_2 + 1024*((-3) + x0) + 4096*x1), tmp18 & xmask, other=0.0).to(tl.float32)
    tmp20 = tmp0 >= tmp16
    tmp21 = tl.full([1, 1], 11, tl.int64)
    tmp22 = tmp0 < tmp21
    tmp23 = tl.load(in_ptr4 + (r0_2 + 1024*((-7) + x0) + 4096*x1), tmp20 & xmask, other=0.0).to(tl.float32)
    tmp24 = tl.where(tmp18, tmp19, tmp23)
    tmp25 = tl.where(tmp9, tmp14, tmp24)
    tmp26 = tl.where(tmp4, tmp5, tmp25)
    tmp27 = tmp26.to(tl.float32)
    tmp28 = tmp27 * tmp27
    tmp29 = tl.broadcast_to(tmp28, [XBLOCK, R0_BLOCK])
    tmp31 = tl.where(xmask, tmp29, 0)
    tmp32 = tl.sum(tmp31, 1)[:, None].to(tl.float32)
    tmp33 = tl.full([1, 1], 1024.0, tl.float32)
    tmp34 = (tmp32 / tmp33)
    tmp35 = tl.full([1, 1], 1e-05, tl.float32)
    tmp36 = tmp34 + tmp35
    tmp37 = libdevice.rsqrt(tmp36)
    tmp38 = tmp27 * tmp37
    tmp39 = tmp38.to(tl.float32)
    tmp41 = tmp39 * tmp40
    tl.store(out_ptr0 + (r0_2 + 1024*x3), tmp26, xmask)
    tl.store(out_ptr2 + (r0_2 + 1024*x3), tmp41, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/42/c42i4uisntm6yj7zvzwxoqpx5qrc2khkw57xahb2gppln6n2ob6u.py
# Topologically Sorted Source Nodes: [query_states, view_1, query_states_1, q, chunk, key_states, view_2, key_states_1, k, chunk_1, position_ids, cos_2, mul_8, neg, cat_3, sin_2, mul_9, q_embed, query_states_2, query_states_3, mul_10, neg_1, cat_4, mul_11, k_embed, key_states_2, key_states_3, value_states, view_3, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_3 => cat_3
#   cat_4 => cat_4
#   chunk => split
#   chunk_1 => split_1
#   cos_2 => index
#   k => convert_element_type_33
#   k_embed => add_7
#   key_states => view_8
#   key_states_1 => permute_12
#   key_states_2 => convert_element_type_35
#   key_states_3 => clone_3
#   mul_10 => mul_12
#   mul_11 => mul_13
#   mul_8 => mul_10
#   mul_9 => mul_11
#   neg => neg_2
#   neg_1 => neg_3
#   position_ids => iota_2
#   q => convert_element_type_32
#   q_embed => add_6
#   query_states => view_6
#   query_states_1 => permute_11
#   query_states_2 => convert_element_type_34
#   query_states_3 => clone_2
#   sin_2 => index_1
#   value_states => view_10
#   value_states_1 => permute_13
#   value_states_2 => clone_4
#   view_1 => view_11
#   view_2 => view_12
#   view_3 => view_13
# Graph fragment:
#   %mm : Tensor "bf16[22, 2048][2048, 1]cuda:0" = PlaceHolder[target=mm]
#   %arg17_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg17_1]
#   %arg18_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg18_1]
#   %view_6 : Tensor "bf16[2, 11, 2048][22528, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [2, 11, 2048]), kwargs = {})
#   %view_11 : Tensor "bf16[2, 11, 16, 128][22528, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [2, 11, 16, 128]), kwargs = {})
#   %permute_11 : Tensor "bf16[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_32 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_11, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_32, 64, -1), kwargs = {})
#   %view_8 : Tensor "bf16[2, 11, 256][2816, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [2, 11, 256]), kwargs = {})
#   %view_12 : Tensor "bf16[2, 11, 2, 128][2816, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [2, 11, 2, 128]), kwargs = {})
#   %permute_12 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_12, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_33 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_12, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_33, 64, -1), kwargs = {})
#   %iota_2 : Tensor "i64[11][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (11,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[11, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg17_1, [%iota_2]), kwargs = {})
#   %mul_10 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_32, %index), kwargs = {})
#   %neg_2 : Tensor "f32[2, 16, 11, 64][11264, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_3 : Tensor "f32[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_2, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[11, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg18_1, [%iota_2]), kwargs = {})
#   %mul_11 : Tensor "f32[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_3, %index_1), kwargs = {})
#   %add_6 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_10, %mul_11), kwargs = {})
#   %convert_element_type_34 : Tensor "bf16[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_6, torch.bfloat16), kwargs = {})
#   %clone_2 : Tensor "bf16[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_34,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_12 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_33, %index), kwargs = {})
#   %neg_3 : Tensor "f32[2, 2, 11, 64][1408, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_4 : Tensor "f32[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_3, %getitem_2], -1), kwargs = {})
#   %mul_13 : Tensor "f32[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_4, %index_1), kwargs = {})
#   %add_7 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_12, %mul_13), kwargs = {})
#   %convert_element_type_35 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.bfloat16), kwargs = {})
#   %clone_3 : Tensor "bf16[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_35,), kwargs = {memory_format: torch.contiguous_format})
#   %view_10 : Tensor "bf16[2, 11, 256][2816, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [2, 11, 256]), kwargs = {})
#   %view_13 : Tensor "bf16[2, 11, 2, 128][2816, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_10, [2, 11, 2, 128]), kwargs = {})
#   %permute_13 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_13, [0, 2, 1, 3]), kwargs = {})
#   %clone_4 : Tensor "bf16[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_13,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone_2, %clone_3, %clone_4), kwargs = {scale: 0.08838834764831843})
#   return %buf18
# Registers a pointwise Triton kernel whose source is the string literal below.
# Per the signature in triton_meta, all three inputs and the output are bf16
# pointers; every loaded value is converted to float32 before arithmetic.
#
# The kernel covers xnumel = 45056 flat indices and uses no bounds mask: every
# load/store passes None as the mask (or only the half-selection predicate),
# and the all-True xmask it builds is never read.
#
# For a flat index i, with
#     x0 = i % 128, x1 = (i // 128) % 16, x2 = (i // 2048) % 11,
#     x3 = i // 22528, row = i // 128
# it computes:
#     a = in_ptr0[i] * in_ptr1[x0 + 128*x2]
#     r = -in_ptr0[128*row + x0 + 64]    if x0 < 64
#          in_ptr0[128*row + x0 - 64]    otherwise
#     out_ptr0[x0 + 128*x2 + 1408*x1 + 22528*x3] = a + r * in_ptr2[x0 + 128*x2]
# The read order is x0 + 128*x1 + 2048*x2 + 22528*x3, so the x1 and x2 index
# components trade places between the input and output layouts.
#
# tmp6/tmp7 and tmp16/tmp17 are computed but never read.
#
# NOTE(review): this has the shape of a rotate-half rotary embedding on the
# query states; presumably in_ptr1 / in_ptr2 are the cos / sin tables --
# confirm against the call site.
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 65536}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 456192}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 45056
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x4 = xindex
    x0 = (xindex % 128)
    x2 = ((xindex // 2048) % 11)
    x5 = xindex // 128
    x1 = ((xindex // 128) % 16)
    x3 = xindex // 22528
    tmp0 = tl.load(in_ptr0 + (x4), None).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0 + 128*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (x0 + 128*x2), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = x0
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 64, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tl.load(in_ptr0 + (64 + 128*x5 + (x0)), tmp9, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp5 >= tmp8
    tmp16 = tl.full([1], 128, tl.int64)
    tmp17 = tmp5 < tmp16
    tmp18 = tl.load(in_ptr0 + (128*x5 + ((-64) + x0)), tmp15, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tl.where(tmp9, tmp14, tmp21)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp4 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*x2 + 1408*x1 + 22528*x3), tmp27, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/qo/cqoymccazb5ddcphxpypbfxh3pm5npit6h4jouuola2rpjmx6xwc.py
# Topologically Sorted Source Nodes: [query_states, view_1, query_states_1, q, chunk, key_states, view_2, key_states_1, k, chunk_1, position_ids, cos_2, mul_8, neg, cat_3, sin_2, mul_9, q_embed, query_states_2, query_states_3, mul_10, neg_1, cat_4, mul_11, k_embed, key_states_2, key_states_3, value_states, view_3, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_3 => cat_3
#   cat_4 => cat_4
#   chunk => split
#   chunk_1 => split_1
#   cos_2 => index
#   k => convert_element_type_33
#   k_embed => add_7
#   key_states => view_8
#   key_states_1 => permute_12
#   key_states_2 => convert_element_type_35
#   key_states_3 => clone_3
#   mul_10 => mul_12
#   mul_11 => mul_13
#   mul_8 => mul_10
#   mul_9 => mul_11
#   neg => neg_2
#   neg_1 => neg_3
#   position_ids => iota_2
#   q => convert_element_type_32
#   q_embed => add_6
#   query_states => view_6
#   query_states_1 => permute_11
#   query_states_2 => convert_element_type_34
#   query_states_3 => clone_2
#   sin_2 => index_1
#   value_states => view_10
#   value_states_1 => permute_13
#   value_states_2 => clone_4
#   view_1 => view_11
#   view_2 => view_12
#   view_3 => view_13
# Graph fragment:
#   %mm_1 : Tensor "bf16[22, 256][256, 1]cuda:0" = PlaceHolder[target=mm_1]
#   %arg17_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg17_1]
#   %arg18_1 : Tensor "bf16[32768, 128][128, 1]cuda:0" = PlaceHolder[target=arg18_1]
#   %view_6 : Tensor "bf16[2, 11, 2048][22528, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [2, 11, 2048]), kwargs = {})
#   %view_11 : Tensor "bf16[2, 11, 16, 128][22528, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [2, 11, 16, 128]), kwargs = {})
#   %permute_11 : Tensor "bf16[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_32 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_11, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_32, 64, -1), kwargs = {})
#   %view_8 : Tensor "bf16[2, 11, 256][2816, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [2, 11, 256]), kwargs = {})
#   %view_12 : Tensor "bf16[2, 11, 2, 128][2816, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [2, 11, 2, 128]), kwargs = {})
#   %permute_12 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_12, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_33 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_12, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_33, 64, -1), kwargs = {})
#   %iota_2 : Tensor "i64[11][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (11,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[11, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg17_1, [%iota_2]), kwargs = {})
#   %mul_10 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_32, %index), kwargs = {})
#   %neg_2 : Tensor "f32[2, 16, 11, 64][11264, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_3 : Tensor "f32[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_2, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[11, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg18_1, [%iota_2]), kwargs = {})
#   %mul_11 : Tensor "f32[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_3, %index_1), kwargs = {})
#   %add_6 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_10, %mul_11), kwargs = {})
#   %convert_element_type_34 : Tensor "bf16[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_6, torch.bfloat16), kwargs = {})
#   %clone_2 : Tensor "bf16[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_34,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_12 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_33, %index), kwargs = {})
#   %neg_3 : Tensor "f32[2, 2, 11, 64][1408, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_4 : Tensor "f32[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_3, %getitem_2], -1), kwargs = {})
#   %mul_13 : Tensor "f32[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_4, %index_1), kwargs = {})
#   %add_7 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_12, %mul_13), kwargs = {})
#   %convert_element_type_35 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.bfloat16), kwargs = {})
#   %clone_3 : Tensor "bf16[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_35,), kwargs = {memory_format: torch.contiguous_format})
#   %view_10 : Tensor "bf16[2, 11, 256][2816, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [2, 11, 256]), kwargs = {})
#   %view_13 : Tensor "bf16[2, 11, 2, 128][2816, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_10, [2, 11, 2, 128]), kwargs = {})
#   %permute_13 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_13, [0, 2, 1, 3]), kwargs = {})
#   %clone_4 : Tensor "bf16[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_13,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone_2, %clone_3, %clone_4), kwargs = {scale: 0.08838834764831843})
#   return %buf19
# Registers a pointwise Triton kernel whose source is the string literal below.
# Per the signature in triton_meta, all three inputs and the output are bf16
# pointers; every loaded value is converted to float32 before arithmetic.
#
# The kernel covers xnumel = 5632 flat indices; lanes with xindex >= xnumel
# are masked off on every load and on the store.
#
# For a flat index i, with
#     x0 = i % 128, x1 = (i // 128) % 2, x2 = (i // 256) % 11,
#     x3 = i // 2816, row = i // 128
# it computes:
#     a = in_ptr0[i] * in_ptr1[x0 + 128*x2]
#     r = -in_ptr0[128*row + x0 + 64]    if x0 < 64
#          in_ptr0[128*row + x0 - 64]    otherwise
#     out_ptr0[x0 + 128*x2 + 1408*x1 + 2816*x3] = a + r * in_ptr2[x0 + 128*x2]
# The read order is x0 + 128*x1 + 256*x2 + 2816*x3, so the x1 and x2 index
# components trade places between the input and output layouts.
#
# tmp6/tmp7 and tmp16/tmp17 are computed but never read.
#
# NOTE(review): this has the shape of a rotate-half rotary embedding on the
# key states (mm_1 -> k_embed -> clone_3 in the graph fragment); presumably
# in_ptr1 / in_ptr2 are the cos / sin tables (arg17_1 / arg18_1) -- confirm
# against the call site.
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8192}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 61952}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5632
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x4 = xindex
    x0 = (xindex % 128)
    x2 = ((xindex // 256) % 11)
    x5 = xindex // 128
    x1 = ((xindex // 128) % 2)
    x3 = xindex // 2816
    tmp0 = tl.load(in_ptr0 + (x4), xmask).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp23 = tl.load(in_ptr2 + (x0 + 128*x2), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = x0
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 64, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tl.load(in_ptr0 + (64 + 128*x5 + (x0)), tmp9 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp11 = tmp10.to(tl.float32)
    tmp12 = -tmp11
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp9, tmp12, tmp13)
    tmp15 = tmp5 >= tmp8
    tmp16 = tl.full([1], 128, tl.int64)
    tmp17 = tmp5 < tmp16
    tmp18 = tl.load(in_ptr0 + (128*x5 + ((-64) + x0)), tmp15 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp19 = tmp18.to(tl.float32)
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp15, tmp19, tmp20)
    tmp22 = tl.where(tmp9, tmp14, tmp21)
    tmp24 = tmp23.to(tl.float32)
    tmp25 = tmp22 * tmp24
    tmp26 = tmp4 + tmp25
    tmp27 = tmp26.to(tl.float32)
    tl.store(out_ptr0 + (x0 + 128*x2 + 1408*x1 + 2816*x3), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/v4/cv44j2ubwaas35wo3dxrtdl57xobigblyw6dn7ytxjrkkd4iqw2w.py
# Topologically Sorted Source Nodes: [query_states, view_1, query_states_1, q, chunk, key_states, view_2, key_states_1, k, chunk_1, position_ids, cos_2, mul_8, neg, cat_3, sin_2, mul_9, q_embed, query_states_2, query_states_3, mul_10, neg_1, cat_4, mul_11, k_embed, key_states_2, key_states_3, value_states, view_3, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
# Source node to ATen node mapping:
#   attn_output => _scaled_dot_product_flash_attention
#   cat_3 => cat_3
#   cat_4 => cat_4
#   chunk => split
#   chunk_1 => split_1
#   cos_2 => index
#   k => convert_element_type_33
#   k_embed => add_7
#   key_states => view_8
#   key_states_1 => permute_12
#   key_states_2 => convert_element_type_35
#   key_states_3 => clone_3
#   mul_10 => mul_12
#   mul_11 => mul_13
#   mul_8 => mul_10
#   mul_9 => mul_11
#   neg => neg_2
#   neg_1 => neg_3
#   position_ids => iota_2
#   q => convert_element_type_32
#   q_embed => add_6
#   query_states => view_6
#   query_states_1 => permute_11
#   query_states_2 => convert_element_type_34
#   query_states_3 => clone_2
#   sin_2 => index_1
#   value_states => view_10
#   value_states_1 => permute_13
#   value_states_2 => clone_4
#   view_1 => view_11
#   view_2 => view_12
#   view_3 => view_13
# Graph fragment:
#   %mm_2 : Tensor "bf16[22, 256][256, 1]cuda:0" = PlaceHolder[target=mm_2]
#   %view_6 : Tensor "bf16[2, 11, 2048][22528, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm, [2, 11, 2048]), kwargs = {})
#   %view_11 : Tensor "bf16[2, 11, 16, 128][22528, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_6, [2, 11, 16, 128]), kwargs = {})
#   %permute_11 : Tensor "bf16[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_11, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_32 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_11, torch.float32), kwargs = {})
#   %split : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_32, 64, -1), kwargs = {})
#   %view_8 : Tensor "bf16[2, 11, 256][2816, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_1, [2, 11, 256]), kwargs = {})
#   %view_12 : Tensor "bf16[2, 11, 2, 128][2816, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_8, [2, 11, 2, 128]), kwargs = {})
#   %permute_12 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_12, [0, 2, 1, 3]), kwargs = {})
#   %convert_element_type_33 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%permute_12, torch.float32), kwargs = {})
#   %split_1 : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%convert_element_type_33, 64, -1), kwargs = {})
#   %iota_2 : Tensor "i64[11][1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.iota.default](args = (11,), kwargs = {start: 0, step: 1, dtype: torch.int64, device: cuda:0, requires_grad: False})
#   %index : Tensor "bf16[11, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg17_1, [%iota_2]), kwargs = {})
#   %mul_10 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_32, %index), kwargs = {})
#   %neg_2 : Tensor "f32[2, 16, 11, 64][11264, 64, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_1,), kwargs = {})
#   %cat_3 : Tensor "f32[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_2, %getitem], -1), kwargs = {})
#   %index_1 : Tensor "bf16[11, 128][128, 1]cuda:0"[num_users=24] = call_function[target=torch.ops.aten.index.Tensor](args = (%arg18_1, [%iota_2]), kwargs = {})
#   %mul_11 : Tensor "f32[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_3, %index_1), kwargs = {})
#   %add_6 : Tensor "f32[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_10, %mul_11), kwargs = {})
#   %convert_element_type_34 : Tensor "bf16[2, 16, 11, 128][22528, 128, 2048, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_6, torch.bfloat16), kwargs = {})
#   %clone_2 : Tensor "bf16[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_34,), kwargs = {memory_format: torch.contiguous_format})
#   %mul_12 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_33, %index), kwargs = {})
#   %neg_3 : Tensor "f32[2, 2, 11, 64][1408, 64, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%getitem_3,), kwargs = {})
#   %cat_4 : Tensor "f32[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_3, %getitem_2], -1), kwargs = {})
#   %mul_13 : Tensor "f32[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_4, %index_1), kwargs = {})
#   %add_7 : Tensor "f32[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_12, %mul_13), kwargs = {})
#   %convert_element_type_35 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_7, torch.bfloat16), kwargs = {})
#   %clone_3 : Tensor "bf16[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%convert_element_type_35,), kwargs = {memory_format: torch.contiguous_format})
#   %view_10 : Tensor "bf16[2, 11, 256][2816, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_2, [2, 11, 256]), kwargs = {})
#   %view_13 : Tensor "bf16[2, 11, 2, 128][2816, 256, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%view_10, [2, 11, 2, 128]), kwargs = {})
#   %permute_13 : Tensor "bf16[2, 2, 11, 128][2816, 128, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%view_13, [0, 2, 1, 3]), kwargs = {})
#   %clone_4 : Tensor "bf16[2, 2, 11, 128][2816, 1408, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_13,), kwargs = {memory_format: torch.contiguous_format})
#   %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%clone_2, %clone_3, %clone_4), kwargs = {scale: 0.08838834764831843})
#   return %buf20
# Registers a pointwise Triton kernel whose source is the string literal below.
# It is a pure gather-copy between two bf16 buffers (see the signature in
# triton_meta): one load, one store, no arithmetic.
#
# The kernel covers xnumel = 5632 flat indices; lanes with xindex >= xnumel
# are masked off on both the load and the store.
#
# For a flat output index i, with
#     x0 = i % 128, x1 = (i // 128) % 11, x2 = (i // 1408) % 2, x3 = i // 2816
# it performs:
#     out_ptr0[i] = in_ptr0[x0 + 128*x2 + 256*x1 + 2816*x3]
# The output is written in linear order while the input is read with the
# x1 and x2 components swapped.
#
# Per the graph fragment above, this matches clone_4: the contiguous copy of
# mm_2 after it is viewed as [2, 11, 2, 128] and permuted to [2, 2, 11, 128].
triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6 = async_compile.triton('triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8192}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 33792}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5632
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 128)
    x1 = ((xindex // 128) % 11)
    x2 = ((xindex // 1408) % 2)
    x3 = xindex // 2816
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + 256*x1 + 2816*x3), xmask).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/la/cla5s5hkkvywfb4von7ktayyu6vnhie3mgqzwbxijoypvpumz3vj.py
# Topologically Sorted Source Nodes: [transpose_5, attn_output_1], Original ATen: [aten.transpose, aten.clone]
# Source node to ATen node mapping:
#   attn_output_1 => clone_5
#   transpose_5 => permute_14
# Graph fragment:
#   %getitem_4 : Tensor "bf16[2, 16, 11, 128][22528, 1408, 128, 1]cuda:0" = PlaceHolder[target=getitem_4]
#   %permute_14 : Tensor "bf16[2, 11, 16, 128][22528, 128, 1408, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%getitem_4, [0, 2, 1, 3]), kwargs = {})
#   %clone_5 : Tensor "bf16[2, 11, 16, 128][22528, 2048, 128, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_14,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_5
# Registers a pointwise Triton kernel whose source is the string literal below.
# It is a pure gather-copy between two bf16 buffers (see the signature in
# triton_meta): one load, one store, no arithmetic.
#
# The kernel covers xnumel = 45056 flat indices and uses no bounds mask: the
# load and the store both pass None as the mask, and the all-True xmask it
# builds is never read.
#
# For a flat output index i, with
#     x0 = i % 128, x1 = (i // 128) % 16, x2 = (i // 2048) % 11, x3 = i // 22528
# it performs:
#     out_ptr0[i] = in_ptr0[x0 + 128*x2 + 1408*x1 + 22528*x3]
# The output is written in linear order while the input is read with the
# x1 and x2 components swapped.
#
# Per the graph fragment above, this matches clone_5: the contiguous
# [2, 11, 16, 128] copy of getitem_4 ([2, 16, 11, 128]) after permute_14.
triton_poi_fused_clone_transpose_7 = async_compile.triton('triton_poi_fused_clone_transpose_7', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 65536}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_transpose_7', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 270336}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clone_transpose_7(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 45056
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = (xindex % 128)
    x1 = ((xindex // 128) % 16)
    x2 = ((xindex // 2048) % 11)
    x3 = xindex // 22528
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 128*x2 + 1408*x1 + 22528*x3), None).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/wh/cwhsqcapkxam7mfkpnrc56pwbpfog63d2t2sa5ncq5zge5pyvhm3.py
# Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, to_8, pow_2, variance_1, add_5, rsqrt_1, mul_12, hidden_1, hidden_states_2], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_5 => add_9
#   attn_output_3 => view_16
#   hidden_1 => convert_element_type_39
#   hidden_states_1 => add_8
#   hidden_states_2 => mul_15
#   mul_12 => mul_14
#   pow_2 => pow_2
#   rsqrt_1 => rsqrt_1
#   to_8 => convert_element_type_38
#   variance_1 => mean_1
# Graph fragment:
#   %cat_2 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=cat_2]
#   %mm_3 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %buf29 : Tensor "f32[2, 11, 1][11, 1, 22]cuda:0" = PlaceHolder[target=buf29]
#   %arg24_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg24_1]
#   %view_16 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [2, 11, 1024]), kwargs = {})
#   %add_8 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat_2, %view_16), kwargs = {})
#   %convert_element_type_38 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_8, torch.float32), kwargs = {})
#   %pow_2 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_38, 2), kwargs = {})
#   %mean_1 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
#   %add_9 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
#   %rsqrt_1 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_9,), kwargs = {})
#   %mul_14 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_8, %rsqrt_1), kwargs = {})
#   %convert_element_type_39 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_14, torch.bfloat16), kwargs = {})
#   %mul_15 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_39, %arg24_1), kwargs = {})
#   return %buf29,%mul_15
# Passes the kernel name and its Triton source string to async_compile.triton
# (device_str='cuda') and binds the result to a module-level name.
#
# What the embedded kernel does, per its body:
#   * xnumel is fixed to 22 rows and r0_numel to 1024 columns; the whole row is
#     handled in one block (R0_BLOCK == 1024), i.e. a persistent reduction.
#   * in_ptr0 and in_ptr1 are read at offset r0 + 1024*x, cast to float32 and added.
#   * The sum is squared, rows with x >= 22 are zeroed via tl.where(xmask, ..., 0),
#     and the squares are summed along axis 1.
#   * scale = rsqrt(sum / 1024.0 + 1e-05); the float32 row sum is multiplied by it.
#   * The result is multiplied by in_ptr2[r0] (indexed by column only, so the same
#     vector is applied to every row) and stored to out_ptr1 at r0 + 1024*x.
#   * Loads from in_ptr0/in_ptr1 and the final store are masked with xmask; the
#     in_ptr2 load is unmasked.
# All pointers are declared '*bf16' in the triton_meta signature.
# NOTE(review): the source string is consumed at runtime by Inductor (compiled,
# and presumably hashed for caching), so its contents must not be edited by hand.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 182272}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 22
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp16 = tl.load(in_ptr2 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp3 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK, R0_BLOCK])
    tmp7 = tl.where(xmask, tmp5, 0)
    tmp8 = tl.sum(tmp7, 1)[:, None].to(tl.float32)
    tmp9 = tl.full([1, 1], 1024.0, tl.float32)
    tmp10 = (tmp8 / tmp9)
    tmp11 = tl.full([1, 1], 1e-05, tl.float32)
    tmp12 = tmp10 + tmp11
    tmp13 = libdevice.rsqrt(tmp12)
    tmp14 = tmp3 * tmp13
    tmp15 = tmp14.to(tl.float32)
    tmp17 = tmp15 * tmp16
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp17, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/7g/c7gcru6notiznszbftvyvptdrca4bgdgzzsacjeiygcljbgv4eqo.py
# Topologically Sorted Source Nodes: [linear_10, silu_2, linear_11, mul_14], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
# Source node to ATen node mapping:
#   linear_10 => view_18
#   linear_11 => view_20
#   mul_14 => mul_16
#   silu_2 => add_10, convert_element_type_42, convert_element_type_43, div_2, exp_4, neg_4
# Graph fragment:
#   %mm_4 : Tensor "bf16[22, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_4]
#   %mm_5 : Tensor "bf16[22, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_5]
#   %view_18 : Tensor "bf16[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_4, [2, 11, 4096]), kwargs = {})
#   %convert_element_type_42 : Tensor "f32[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_18, torch.float32), kwargs = {})
#   %neg_4 : Tensor "f32[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%convert_element_type_42,), kwargs = {})
#   %exp_4 : Tensor "f32[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%neg_4,), kwargs = {})
#   %add_10 : Tensor "f32[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%exp_4, 1), kwargs = {})
#   %div_2 : Tensor "f32[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%convert_element_type_42, %add_10), kwargs = {})
#   %convert_element_type_43 : Tensor "bf16[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%div_2, torch.bfloat16), kwargs = {})
#   %view_20 : Tensor "bf16[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_5, [2, 11, 4096]), kwargs = {})
#   %mul_16 : Tensor "bf16[2, 11, 4096][45056, 4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_43, %view_20), kwargs = {})
#   return %mul_16
# Passes the kernel name and its Triton source string to async_compile.triton
# (device_str='cuda') and binds the result to a module-level name.
#
# What the embedded kernel does, per its body:
#   * Pointwise over a flat index x with xnumel fixed to 90112.
#   * Loads a = in_out_ptr0[x] and b = in_ptr0[x], both cast to float32.
#   * Computes a / (exp(-a) + 1.0) and multiplies it by b.
#   * Stores the product back to in_out_ptr0[x], i.e. the first buffer is
#     overwritten in place (it is listed in 'mutated_arg_names').
#   * xmask is built as all-True but is not passed to tl.load/tl.store; both use
#     mask None, so no bounds masking is applied.
#     NOTE(review): this presumably relies on the launch grid covering exactly
#     90112 elements - confirm against the heuristics' XBLOCK choice.
# NOTE(review): the source string is consumed at runtime by Inductor, so its
# contents must not be edited by hand.
triton_poi_fused__unsafe_view_mul_silu_9 = async_compile.triton('triton_poi_fused__unsafe_view_mul_silu_9', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 131072}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__unsafe_view_mul_silu_9', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 720896}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__unsafe_view_mul_silu_9(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 90112
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = -tmp1
    tmp3 = libdevice.exp(tmp2)
    tmp4 = tl.full([1], 1.0, tl.float32)
    tmp5 = tmp3 + tmp4
    tmp6 = (tmp1 / tmp5)
    tmp7 = tmp6.to(tl.float32)
    tmp9 = tmp7 * tmp8
    tl.store(in_out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/fv/cfvkz7mee3bpd6g6gkloitiicyp5uqj7pzd5ijznomol6c3ocxvf.py
# Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, to_10, pow_3, variance_2, add_7, rsqrt_2, mul_15, hidden_2, hidden_states_5], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_7 => add_12
#   attn_output_3 => view_16
#   hidden_2 => convert_element_type_49
#   hidden_states_1 => add_8
#   hidden_states_3 => view_22
#   hidden_states_4 => add_11
#   hidden_states_5 => mul_18
#   mul_15 => mul_17
#   pow_3 => pow_3
#   rsqrt_2 => rsqrt_2
#   to_10 => convert_element_type_48
#   variance_2 => mean_2
# Graph fragment:
#   %cat_2 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=cat_2]
#   %mm_3 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %buf35 : Tensor "f32[2, 11, 1][11, 1, 22]cuda:0" = PlaceHolder[target=buf35]
#   %arg28_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg28_1]
#   %view_16 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [2, 11, 1024]), kwargs = {})
#   %add_8 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat_2, %view_16), kwargs = {})
#   %view_22 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_6, [2, 11, 1024]), kwargs = {})
#   %add_11 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_8, %view_22), kwargs = {})
#   %convert_element_type_48 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_11, torch.float32), kwargs = {})
#   %pow_3 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_48, 2), kwargs = {})
#   %mean_2 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
#   %add_12 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-05), kwargs = {})
#   %rsqrt_2 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_12,), kwargs = {})
#   %mul_17 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_11, %rsqrt_2), kwargs = {})
#   %convert_element_type_49 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_17, torch.bfloat16), kwargs = {})
#   %mul_18 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_49, %arg28_1), kwargs = {})
#   return %buf35,%mul_18
# Passes the kernel name and its Triton source string to async_compile.triton
# (device_str='cuda') and binds the result to a module-level name.
#
# What the embedded kernel does, per its body:
#   * xnumel is fixed to 22 rows and r0_numel to 1024 columns; each row is
#     reduced in a single block (R0_BLOCK == 1024).
#   * in_ptr0, in_ptr1 and in_ptr2 are read at offset r0 + 1024*x, cast to
#     float32 and summed (in_ptr0 + in_ptr1, then + in_ptr2).
#   * The sum is squared, rows with x >= 22 are zeroed via tl.where(xmask, ..., 0),
#     and the squares are summed along axis 1.
#   * scale = rsqrt(sum / 1024.0 + 1e-05); the float32 row sum is multiplied by it.
#   * The result is multiplied by in_ptr3[r0] (indexed by column only) and
#     stored to out_ptr1 at r0 + 1024*x under xmask.
# All pointers are declared '*bf16' in the triton_meta signature.
# NOTE(review): the source string is consumed at runtime by Inductor, so its
# contents must not be edited by hand.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 227328}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 22
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp18 = tl.load(in_ptr3 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp5 * tmp5
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
    tmp9 = tl.where(xmask, tmp7, 0)
    tmp10 = tl.sum(tmp9, 1)[:, None].to(tl.float32)
    tmp11 = tl.full([1, 1], 1024.0, tl.float32)
    tmp12 = (tmp10 / tmp11)
    tmp13 = tl.full([1, 1], 1e-05, tl.float32)
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp5 * tmp15
    tmp17 = tmp16.to(tl.float32)
    tmp19 = tmp17 * tmp18
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp19, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/jw/cjwung2bft2kcmta7schy3nlnfwldlwa3tfp3fc3xe6x4yii665p.py
# Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, attn_output_7, hidden_states_6, to_16, pow_4, variance_3, add_11, rsqrt_3, mul_21, hidden_3, hidden_states_7], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_11 => add_16
#   attn_output_3 => view_16
#   attn_output_7 => view_34
#   hidden_3 => convert_element_type_63
#   hidden_states_1 => add_8
#   hidden_states_3 => view_22
#   hidden_states_4 => add_11
#   hidden_states_6 => add_15
#   hidden_states_7 => mul_24
#   mul_21 => mul_23
#   pow_4 => pow_4
#   rsqrt_3 => rsqrt_3
#   to_16 => convert_element_type_62
#   variance_3 => mean_3
# Graph fragment:
#   %cat_2 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=cat_2]
#   %mm_3 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %mm_10 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_10]
#   %buf51 : Tensor "f32[2, 11, 1][11, 1, 22]cuda:0" = PlaceHolder[target=buf51]
#   %arg33_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg33_1]
#   %view_16 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [2, 11, 1024]), kwargs = {})
#   %add_8 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat_2, %view_16), kwargs = {})
#   %view_22 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_6, [2, 11, 1024]), kwargs = {})
#   %add_11 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_8, %view_22), kwargs = {})
#   %view_34 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_10, [2, 11, 1024]), kwargs = {})
#   %add_15 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_34), kwargs = {})
#   %convert_element_type_62 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_15, torch.float32), kwargs = {})
#   %pow_4 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_62, 2), kwargs = {})
#   %mean_3 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_4, [-1], True), kwargs = {})
#   %add_16 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_3, 1e-05), kwargs = {})
#   %rsqrt_3 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_16,), kwargs = {})
#   %mul_23 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_15, %rsqrt_3), kwargs = {})
#   %convert_element_type_63 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_23, torch.bfloat16), kwargs = {})
#   %mul_24 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_63, %arg33_1), kwargs = {})
#   return %buf51,%mul_24
# Passes the kernel name and its Triton source string to async_compile.triton
# (device_str='cuda') and binds the result to a module-level name.
#
# What the embedded kernel does, per its body:
#   * xnumel is fixed to 22 rows and r0_numel to 1024 columns; each row is
#     reduced in a single block (R0_BLOCK == 1024).
#   * in_ptr0 .. in_ptr3 are read at offset r0 + 1024*x, cast to float32 and
#     summed left to right (((in_ptr0 + in_ptr1) + in_ptr2) + in_ptr3).
#   * The sum is squared, rows with x >= 22 are zeroed via tl.where(xmask, ..., 0),
#     and the squares are summed along axis 1.
#   * scale = rsqrt(sum / 1024.0 + 1e-05); the float32 row sum is multiplied by it.
#   * The result is multiplied by in_ptr4[r0] (indexed by column only) and
#     stored to out_ptr1 at r0 + 1024*x under xmask.
# All pointers are declared '*bf16' in the triton_meta signature.
# NOTE(review): the source string is consumed at runtime by Inductor, so its
# contents must not be edited by hand.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 272384}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 22
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp20 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tmp7 * tmp7
    tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
    tmp11 = tl.where(xmask, tmp9, 0)
    tmp12 = tl.sum(tmp11, 1)[:, None].to(tl.float32)
    tmp13 = tl.full([1, 1], 1024.0, tl.float32)
    tmp14 = (tmp12 / tmp13)
    tmp15 = tl.full([1, 1], 1e-05, tl.float32)
    tmp16 = tmp14 + tmp15
    tmp17 = libdevice.rsqrt(tmp16)
    tmp18 = tmp7 * tmp17
    tmp19 = tmp18.to(tl.float32)
    tmp21 = tmp19 * tmp20
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp21, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/ks/ckstpjtlthynnlplc73wgbhsmdjsu735gvxxnkdoek26v6au5knr.py
# Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, attn_output_7, hidden_states_6, hidden_states_8, hidden_states_9, to_18, pow_5, variance_4, add_13, rsqrt_4, mul_24, hidden_4, hidden_states_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
#   add_13 => add_19
#   attn_output_3 => view_16
#   attn_output_7 => view_34
#   hidden_4 => convert_element_type_73
#   hidden_states_1 => add_8
#   hidden_states_10 => mul_27
#   hidden_states_3 => view_22
#   hidden_states_4 => add_11
#   hidden_states_6 => add_15
#   hidden_states_8 => view_40
#   hidden_states_9 => add_18
#   mul_24 => mul_26
#   pow_5 => pow_5
#   rsqrt_4 => rsqrt_4
#   to_18 => convert_element_type_72
#   variance_4 => mean_4
# Graph fragment:
#   %cat_2 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=cat_2]
#   %mm_3 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_3]
#   %mm_6 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_6]
#   %mm_10 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_10]
#   %mm_13 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_13]
#   %add_18 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=add_18]
#   %buf58 : Tensor "f32[2, 11, 1][11, 1, 22]cuda:0" = PlaceHolder[target=buf58]
#   %arg37_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg37_1]
#   %view_16 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_3, [2, 11, 1024]), kwargs = {})
#   %add_8 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat_2, %view_16), kwargs = {})
#   %view_22 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_6, [2, 11, 1024]), kwargs = {})
#   %add_11 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_8, %view_22), kwargs = {})
#   %view_34 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_10, [2, 11, 1024]), kwargs = {})
#   %add_15 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_11, %view_34), kwargs = {})
#   %view_40 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_13, [2, 11, 1024]), kwargs = {})
#   %add_18 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_15, %view_40), kwargs = {})
#   %convert_element_type_72 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_18, torch.float32), kwargs = {})
#   %pow_5 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_72, 2), kwargs = {})
#   %mean_4 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_5, [-1], True), kwargs = {})
#   %add_19 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_4, 1e-05), kwargs = {})
#   %rsqrt_4 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_19,), kwargs = {})
#   %mul_26 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_18, %rsqrt_4), kwargs = {})
#   %convert_element_type_73 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_26, torch.bfloat16), kwargs = {})
#   %mul_27 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_73, %arg37_1), kwargs = {})
#   return %add_18,%buf58,%mul_27
# Passes the kernel name and its Triton source string to async_compile.triton
# (device_str='cuda') and binds the result to a module-level name.
#
# What the embedded kernel does, per its body:
#   * xnumel is fixed to 22 rows and r0_numel to 1024 columns; each row is
#     reduced in a single block (R0_BLOCK == 1024).
#   * in_out_ptr0 and in_ptr0 .. in_ptr3 are read at offset r0 + 1024*x, cast to
#     float32 and summed left to right, starting from in_out_ptr0.
#   * The sum is squared, rows with x >= 22 are zeroed via tl.where(xmask, ..., 0),
#     and the squares are summed along axis 1.
#   * scale = rsqrt(sum / 1024.0 + 1e-05); the float32 row sum is multiplied by it
#     and then by in_ptr4[r0] (indexed by column only).
#   * Two masked stores are issued: the un-normalized sum is written back into
#     in_out_ptr0 (in place; it is listed in 'mutated_arg_names'), and the
#     scaled result is written to out_ptr1.
# All pointers are declared '*bf16' in the triton_meta signature.
# NOTE(review): the source string is consumed at runtime by Inductor, so its
# contents must not be edited by hand.
triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 6, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 0, 'r0_': 407552}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 22
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp22 = tl.load(in_ptr4 + (r0_1), None, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp8 = tmp6 + tmp7
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp9 * tmp9
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp13 = tl.where(xmask, tmp11, 0)
    tmp14 = tl.sum(tmp13, 1)[:, None].to(tl.float32)
    tmp15 = tl.full([1, 1], 1024.0, tl.float32)
    tmp16 = (tmp14 / tmp15)
    tmp17 = tl.full([1, 1], 1e-05, tl.float32)
    tmp18 = tmp16 + tmp17
    tmp19 = libdevice.rsqrt(tmp18)
    tmp20 = tmp9 * tmp19
    tmp21 = tmp20.to(tl.float32)
    tmp23 = tmp21 * tmp22
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp8, xmask)
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp23, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/yp/cypk4rcf54wp4nz7vhw43w737n3qsnadpioivle6dgpj36kzoshl.py
# Topologically Sorted Source Nodes: [attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, attn_output_47, hidden_states_56, hidden_states_58, hidden_states_59, to_98, pow_25, variance_24], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean]
# Source node to ATen node mapping:
#   attn_output_43 => view_196
#   attn_output_47 => view_214
#   hidden_states_51 => add_78
#   hidden_states_53 => view_202
#   hidden_states_54 => add_81
#   hidden_states_56 => add_85
#   hidden_states_58 => view_220
#   hidden_states_59 => add_88
#   pow_25 => pow_25
#   to_98 => convert_element_type_312
#   variance_24 => mean_24
# Graph fragment:
#   %add_74 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=add_74]
#   %mm_73 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_73]
#   %mm_76 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_76]
#   %mm_80 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_80]
#   %mm_83 : Tensor "bf16[22, 1024][1024, 1]cuda:0" = PlaceHolder[target=mm_83]
#   %add_88 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=add_88]
#   %view_196 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_73, [2, 11, 1024]), kwargs = {})
#   %add_78 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_74, %view_196), kwargs = {})
#   %view_202 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_76, [2, 11, 1024]), kwargs = {})
#   %add_81 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_78, %view_202), kwargs = {})
#   %view_214 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_80, [2, 11, 1024]), kwargs = {})
#   %add_85 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_81, %view_214), kwargs = {})
#   %view_220 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_83, [2, 11, 1024]), kwargs = {})
#   %add_88 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_85, %view_220), kwargs = {})
#   %convert_element_type_312 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_88, torch.float32), kwargs = {})
#   %pow_25 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_312, 2), kwargs = {})
#   %mean_24 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_25, [-1], True), kwargs = {})
#   return %add_88,%buf283
# Persistent-reduction Triton kernel, registered through async_compile.triton.
# The kernel source is the triple-quoted string passed as the second argument.
#
# Sizes are hard-coded in the kernel body: xnumel = 22 rows, r0_numel = 1024
# elements per row, every buffer addressed as (r0 + 1024 * x).
#
# What the body does:
#   1. Loads in_out_ptr0 and in_ptr0..in_ptr3 (all '*bf16' in the signature),
#      converting each loaded value to fp32.
#   2. Adds the five values (tmp8) and stores the sum back through
#      in_out_ptr0, so the first operand is overwritten in place
#      ('mutated_arg_names': ['in_out_ptr0']).
#   3. Squares the sum, zeroes rows outside xmask, reduces along axis 1 and
#      stores one value per row to out_ptr0 ('*fp32').
#
# NOTE: out_ptr0 receives the raw sum of squares. This kernel does not divide
# by r0_numel, so the "mean" in its name is not completed here.
# NOTE(review): presumably the consumer of out_ptr0 applies the 1/1024 factor;
# confirm at the call site.
triton_per_fused__to_copy__unsafe_view_add_mean_pow_13 = async_compile.triton('triton_per_fused__to_copy__unsafe_view_add_mean_pow_13', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 32, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy__unsafe_view_add_mean_pow_13', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': None, 'atomic_add_found': False, 'num_load': 5, 'num_store': 2, 'num_reduction': 1, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 176, 'r0_': 315392}}
)
@triton.jit
def triton_per_fused__to_copy__unsafe_view_add_mean_pow_13(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 22
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)[None, :]
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp3 = tl.load(in_ptr1 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp5 = tl.load(in_ptr2 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r0_1 + 1024*x0), xmask, other=0.0).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp4 = tmp2 + tmp3
    tmp6 = tmp4 + tmp5
    tmp8 = tmp6 + tmp7
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp9 * tmp9
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp13 = tl.where(xmask, tmp11, 0)
    tmp14 = tl.sum(tmp13, 1)[:, None].to(tl.float32)
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp8, xmask)
    tl.store(out_ptr0 + (x0), tmp14, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/p6/cp6natia47ef5pbibslpkehumo4p7zvikt5dqtja2pdxf7du5bln.py
# Topologically Sorted Source Nodes: [to_98, pow_25, variance_24, add_73, rsqrt_24, mul_114, hidden_24, hidden_states_60, hidden_25, hidden_26], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.slice, aten.clone]
# Source node to ATen node mapping:
#   add_73 => add_89
#   hidden_24 => convert_element_type_313
#   hidden_25 => slice_1
#   hidden_26 => clone_50
#   hidden_states_60 => mul_117
#   mul_114 => mul_116
#   pow_25 => pow_25
#   rsqrt_24 => rsqrt_24
#   to_98 => convert_element_type_312
#   variance_24 => mean_24
# Graph fragment:
#   %add_88 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0" = PlaceHolder[target=add_88]
#   %buf283 : Tensor "f32[2, 11, 1][11, 1, 22]cuda:0" = PlaceHolder[target=buf283]
#   %arg127_1 : Tensor "bf16[1024][1]cuda:0" = PlaceHolder[target=arg127_1]
#   %convert_element_type_312 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_88, torch.float32), kwargs = {})
#   %pow_25 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_312, 2), kwargs = {})
#   %mean_24 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_25, [-1], True), kwargs = {})
#   %add_89 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_24, 1e-05), kwargs = {})
#   %rsqrt_24 : Tensor "f32[2, 11, 1][11, 1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_89,), kwargs = {})
#   %mul_116 : Tensor "f32[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_88, %rsqrt_24), kwargs = {})
#   %convert_element_type_313 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_116, torch.bfloat16), kwargs = {})
#   %mul_117 : Tensor "bf16[2, 11, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_313, %arg127_1), kwargs = {})
#   %slice_1 : Tensor "bf16[2, 4, 1024][11264, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mul_117, 1, 7, 9223372036854775807), kwargs = {})
#   %clone_50 : Tensor "bf16[2, 4, 1024][4096, 1024, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%slice_1,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_50
# Pointwise Triton kernel, registered through async_compile.triton.
# It covers xnumel = 8192 output elements and is unmasked: every load and
# store passes None as the mask.
#
# The flat output index is decomposed as
#   x2 = xindex // 4096,  x1 = (xindex // 1024) % 4,  x0 = xindex % 1024
# and each element is computed in fp32 as
#   out_ptr0[xindex] = (in_ptr0[7168 + (xindex % 4096) + 11264 * x2]
#                       * rsqrt(in_ptr1[7 + x1 + 11 * x2] / 1024.0 + 1e-05))
#                      * in_ptr2[x0]
#
# The 7168 (= 7 * 1024) and 7 offsets skip the first seven rows of each group
# of 11, matching the slice(dim=1, start=7) in the graph fragment above.
# The write to out_ptr0 is contiguous.
#
# NOTE(review): in_ptr1 is divided by 1024.0 here, so it presumably holds a
# per-row sum of squares rather than a mean; confirm against the producer.
triton_poi_fused__to_copy_add_clone_mean_mul_pow_rsqrt_slice_14 = async_compile.triton('triton_poi_fused__to_copy_add_clone_mean_mul_pow_rsqrt_slice_14', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8192}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_clone_mean_mul_pow_rsqrt_slice_14', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 51200}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_clone_mean_mul_pow_rsqrt_slice_14(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 8192
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x2 = xindex // 4096
    x3 = (xindex % 4096)
    x1 = ((xindex // 1024) % 4)
    x0 = (xindex % 1024)
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (7168 + x3 + 11264*x2), None).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (7 + x1 + 11*x2), None, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([1], 1024.0, tl.float32)
    tmp4 = (tmp2 / tmp3)
    tmp5 = tl.full([1], 1e-05, tl.float32)
    tmp6 = tmp4 + tmp5
    tmp7 = libdevice.rsqrt(tmp6)
    tmp8 = tmp1 * tmp7
    tmp9 = tmp8.to(tl.float32)
    tmp11 = tmp9 * tmp10
    tl.store(out_ptr0 + (x4), tmp11, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_shingokuga/4l/c4lxwmmxan7una4r65yk4fnd72bebmvdjk3oplrxxsdn4pjex7ku.py
# Topologically Sorted Source Nodes: [hidden_26, transpose_50, contiguous_50], Original ATen: [aten._unsafe_view, aten.add, aten.transpose, aten.clone]
# Source node to ATen node mapping:
#   contiguous_50 => clone_51
#   hidden_26 => add_90, view_222
#   transpose_50 => permute_141
# Graph fragment:
#   %mm_84 : Tensor "bf16[8, 64][64, 1]cuda:0" = PlaceHolder[target=mm_84]
#   %arg129_1 : Tensor "bf16[64][1]cuda:0" = PlaceHolder[target=arg129_1]
#   %view_222 : Tensor "bf16[2, 4, 64][256, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_84, [2, 4, 64]), kwargs = {})
#   %add_90 : Tensor "bf16[2, 4, 64][256, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_222, %arg129_1), kwargs = {})
#   %permute_141 : Tensor "bf16[2, 64, 4][256, 1, 64]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%add_90, [0, 2, 1]), kwargs = {})
#   %clone_51 : Tensor "bf16[2, 64, 4][256, 4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%permute_141,), kwargs = {memory_format: torch.contiguous_format})
#   return %clone_51
# Two-dimensional pointwise Triton kernel ('grid_type': 'Grid2D'), registered
# through async_compile.triton.
#
# Sizes are hard-coded in the kernel body: ynumel = 128 and xnumel = 4.
# program_id(1) tiles the y axis and program_id(0) tiles the x axis.
#
# With y0 = yindex % 64 and y1 = yindex // 64 it computes
#   out_ptr0[x + 4 * yindex] = in_ptr0[y0 + 64 * x + 256 * y1] + in_ptr1[y0]
#
# Since 4 * yindex = 256 * y1 + 4 * y0, the input is read with strides
# (256, 64, 1) over (y1, x, y0) and written with strides (256, 4, 1) over
# (y1, y0, x). The last two axes are therefore swapped while in_ptr1
# (indexed by y0 only) is added, consistent with the add + permute + clone
# sequence in the graph fragment above.
#
# Both loads convert to fp32. The load from in_ptr0 and the store are masked
# with xmask & ymask; the load from in_ptr1 is masked with ymask only.
triton_poi_fused__unsafe_view_add_clone_transpose_15 = async_compile.triton('triton_poi_fused__unsafe_view_add_clone_transpose_15', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 128, 'x': 4}, tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__unsafe_view_add_clone_transpose_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 1152, 'x': 1024}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__unsafe_view_add_clone_transpose_15(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 128
    xnumel = 4
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = xindex < xnumel
    x2 = xindex
    y0 = (yindex % 64)
    y1 = yindex // 64
    y3 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + 64*x2 + 256*y1), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (y0), ymask, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x2 + 4*y3), tmp2, xmask & ymask)
''', device_str='cuda')

def partition_0(args):
    arg6_1, arg8_1, arg7_1, arg10_1, arg9_1, arg11_1, arg13_1, arg12_1, arg15_1, arg14_1, arg3_1, arg5_1, arg4_1, arg0_1, arg2_1, arg1_1, arg16_1, arg19_1, arg20_1, arg21_1, arg22_1, arg17_1, arg18_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1 = args
    args.clear()
    assert_size_stride(arg6_1, (2, ), (1, ))
    assert_size_stride(arg8_1, (1024, ), (1, ))
    assert_size_stride(arg7_1, (1024, 1024), (1024, 1))
    assert_size_stride(arg10_1, (1024, ), (1, ))
    assert_size_stride(arg9_1, (1024, 1024), (1024, 1))
    assert_size_stride(arg11_1, (2, ), (1, ))
    assert_size_stride(arg13_1, (1024, ), (1, ))
    assert_size_stride(arg12_1, (1024, 1024), (1024, 1))
    assert_size_stride(arg15_1, (1024, ), (1, ))
    assert_size_stride(arg14_1, (1024, 1024), (1024, 1))
    assert_size_stride(arg3_1, (2, 64, 4), (256, 4, 1))
    assert_size_stride(arg5_1, (1024, ), (1, ))
    assert_size_stride(arg4_1, (1024, 64), (64, 1))
    assert_size_stride(arg0_1, (2, 64, 4), (256, 4, 1))
    assert_size_stride(arg2_1, (1024, ), (1, ))
    assert_size_stride(arg1_1, (1024, 64), (64, 1))
    assert_size_stride(arg16_1, (2, 2048), (2048, 1))
    assert_size_stride(arg19_1, (1024, ), (1, ))
    assert_size_stride(arg20_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg21_1, (256, 1024), (1024, 1))
    assert_size_stride(arg22_1, (256, 1024), (1024, 1))
    assert_size_stride(arg17_1, (32768, 128), (128, 1))
    assert_size_stride(arg18_1, (32768, 128), (128, 1))
    assert_size_stride(arg23_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg24_1, (1024, ), (1, ))
    assert_size_stride(arg25_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg26_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg27_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg28_1, (1024, ), (1, ))
    assert_size_stride(arg29_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg30_1, (256, 1024), (1024, 1))
    assert_size_stride(arg31_1, (256, 1024), (1024, 1))
    assert_size_stride(arg32_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg33_1, (1024, ), (1, ))
    assert_size_stride(arg34_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg35_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg36_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg37_1, (1024, ), (1, ))
    assert_size_stride(arg38_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg39_1, (256, 1024), (1024, 1))
    assert_size_stride(arg40_1, (256, 1024), (1024, 1))
    assert_size_stride(arg41_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg42_1, (1024, ), (1, ))
    assert_size_stride(arg43_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg44_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg45_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg46_1, (1024, ), (1, ))
    assert_size_stride(arg47_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg48_1, (256, 1024), (1024, 1))
    assert_size_stride(arg49_1, (256, 1024), (1024, 1))
    assert_size_stride(arg50_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg51_1, (1024, ), (1, ))
    assert_size_stride(arg52_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg53_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg54_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg55_1, (1024, ), (1, ))
    assert_size_stride(arg56_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg57_1, (256, 1024), (1024, 1))
    assert_size_stride(arg58_1, (256, 1024), (1024, 1))
    assert_size_stride(arg59_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg60_1, (1024, ), (1, ))
    assert_size_stride(arg61_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg62_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg63_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg64_1, (1024, ), (1, ))
    assert_size_stride(arg65_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg66_1, (256, 1024), (1024, 1))
    assert_size_stride(arg67_1, (256, 1024), (1024, 1))
    assert_size_stride(arg68_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg69_1, (1024, ), (1, ))
    assert_size_stride(arg70_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg71_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg72_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg73_1, (1024, ), (1, ))
    assert_size_stride(arg74_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg75_1, (256, 1024), (1024, 1))
    assert_size_stride(arg76_1, (256, 1024), (1024, 1))
    assert_size_stride(arg77_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg78_1, (1024, ), (1, ))
    assert_size_stride(arg79_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg80_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg81_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg82_1, (1024, ), (1, ))
    assert_size_stride(arg83_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg84_1, (256, 1024), (1024, 1))
    assert_size_stride(arg85_1, (256, 1024), (1024, 1))
    assert_size_stride(arg86_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg87_1, (1024, ), (1, ))
    assert_size_stride(arg88_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg89_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg90_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg91_1, (1024, ), (1, ))
    assert_size_stride(arg92_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg93_1, (256, 1024), (1024, 1))
    assert_size_stride(arg94_1, (256, 1024), (1024, 1))
    assert_size_stride(arg95_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg96_1, (1024, ), (1, ))
    assert_size_stride(arg97_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg98_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg99_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg100_1, (1024, ), (1, ))
    assert_size_stride(arg101_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg102_1, (256, 1024), (1024, 1))
    assert_size_stride(arg103_1, (256, 1024), (1024, 1))
    assert_size_stride(arg104_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg105_1, (1024, ), (1, ))
    assert_size_stride(arg106_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg107_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg108_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg109_1, (1024, ), (1, ))
    assert_size_stride(arg110_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg111_1, (256, 1024), (1024, 1))
    assert_size_stride(arg112_1, (256, 1024), (1024, 1))
    assert_size_stride(arg113_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg114_1, (1024, ), (1, ))
    assert_size_stride(arg115_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg116_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg117_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg118_1, (1024, ), (1, ))
    assert_size_stride(arg119_1, (2048, 1024), (1024, 1))
    assert_size_stride(arg120_1, (256, 1024), (1024, 1))
    assert_size_stride(arg121_1, (256, 1024), (1024, 1))
    assert_size_stride(arg122_1, (1024, 2048), (2048, 1))
    assert_size_stride(arg123_1, (1024, ), (1, ))
    assert_size_stride(arg124_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg125_1, (4096, 1024), (1024, 1))
    assert_size_stride(arg126_1, (1024, 4096), (4096, 1))
    assert_size_stride(arg127_1, (1024, ), (1, ))
    assert_size_stride(arg128_1, (64, 1024), (1024, 1))
    assert_size_stride(arg129_1, (64, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((2, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [unsqueeze, mul_1, arange, mul, emb, unsqueeze_1, emb_1, sin, cos, emb_2], Original ATen: [aten.unsqueeze, aten.mul, aten.arange, aten.exp, aten.sin, aten.cos, aten.cat]
        stream0 = get_raw_stream(0)
        triton_poi_fused_arange_cat_cos_exp_mul_sin_unsqueeze_0.run(arg6_1, buf0, 2048, stream=stream0)
        del arg6_1
        buf1 = empty_strided_cuda((2, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [unsqueeze, mul_1, arange, mul, emb, unsqueeze_1, emb_1, sin, cos, emb_2, sample], Original ATen: [aten.unsqueeze, aten.mul, aten.arange, aten.exp, aten.sin, aten.cos, aten.cat, aten.t, aten.addmm]
        extern_kernels.addmm(arg8_1, buf0, reinterpret_tensor(arg7_1, (1024, 1024), (1, 1024), 0), alpha=1, beta=1, out=buf1)
        del arg7_1
        del arg8_1
        buf2 = buf1; del buf1  # reuse
        # Topologically Sorted Source Nodes: [sample_1], Original ATen: [aten.silu]
        stream0 = get_raw_stream(0)
        triton_poi_fused_silu_1.run(buf2, 2048, stream=stream0)
        buf3 = buf0; del buf0  # reuse
        # Topologically Sorted Source Nodes: [sample_1, sample_2], Original ATen: [aten.silu, aten.t, aten.addmm]
        extern_kernels.addmm(arg10_1, buf2, reinterpret_tensor(arg9_1, (1024, 1024), (1, 1024), 0), alpha=1, beta=1, out=buf3)
        del arg10_1
        del arg9_1
        buf4 = buf2; del buf2  # reuse
        # Topologically Sorted Source Nodes: [unsqueeze_2, mul_4, arange_1, mul_3, emb_3, unsqueeze_3, emb_4, sin_1, cos_1, emb_5], Original ATen: [aten.unsqueeze, aten.mul, aten.arange, aten.exp, aten.sin, aten.cos, aten.cat]
        stream0 = get_raw_stream(0)
        triton_poi_fused_arange_cat_cos_exp_mul_sin_unsqueeze_0.run(arg11_1, buf4, 2048, stream=stream0)
        del arg11_1
        buf5 = empty_strided_cuda((2, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [unsqueeze_2, mul_4, arange_1, mul_3, emb_3, unsqueeze_3, emb_4, sin_1, cos_1, emb_5, sample_3], Original ATen: [aten.unsqueeze, aten.mul, aten.arange, aten.exp, aten.sin, aten.cos, aten.cat, aten.t, aten.addmm]
        extern_kernels.addmm(arg13_1, buf4, reinterpret_tensor(arg12_1, (1024, 1024), (1, 1024), 0), alpha=1, beta=1, out=buf5)
        del arg12_1
        del arg13_1
        buf6 = buf5; del buf5  # reuse
        # Topologically Sorted Source Nodes: [sample_4], Original ATen: [aten.silu]
        stream0 = get_raw_stream(0)
        triton_poi_fused_silu_1.run(buf6, 2048, stream=stream0)
        buf7 = buf4; del buf4  # reuse
        # Topologically Sorted Source Nodes: [sample_4, sample_5], Original ATen: [aten.silu, aten.t, aten.addmm]
        extern_kernels.addmm(arg15_1, buf6, reinterpret_tensor(arg14_1, (1024, 1024), (1, 1024), 0), alpha=1, beta=1, out=buf7)
        del arg14_1
        del arg15_1
        del buf6
        buf8 = empty_strided_cuda((2, 4, 64), (256, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [transpose_1, contiguous_1], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_2.run(arg3_1, buf8, 8, 64, stream=stream0)
        del arg3_1
        buf9 = empty_strided_cuda((8, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [transpose_1, contiguous_1, cond], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.addmm]
        extern_kernels.addmm(arg5_1, reinterpret_tensor(buf8, (8, 64), (64, 1), 0), reinterpret_tensor(arg4_1, (64, 1024), (1, 64), 0), alpha=1, beta=1, out=buf9)
        del arg4_1
        del arg5_1
        buf10 = buf8; del buf8  # reuse
        # Topologically Sorted Source Nodes: [transpose, contiguous], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_2.run(arg0_1, buf10, 8, 64, stream=stream0)
        del arg0_1
        buf11 = empty_strided_cuda((8, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [transpose, contiguous, x], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.addmm]
        extern_kernels.addmm(arg2_1, reinterpret_tensor(buf10, (8, 64), (64, 1), 0), reinterpret_tensor(arg1_1, (64, 1024), (1, 64), 0), alpha=1, beta=1, out=buf11)
        del arg1_1
        del arg2_1
        del buf10
        buf12 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        buf14 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [mu, t_1, unsqueeze_4, cond, x, x_1, to_2, pow_1, variance, add_1, rsqrt, mul_6, hidden, hidden_states], Original ATen: [aten.view, aten.add, aten.unsqueeze, aten.cat, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy_add_cat_mean_mul_pow_rsqrt_unsqueeze_view_3.run(arg16_1, buf3, buf7, buf9, buf11, arg19_1, buf12, buf14, 22, 1024, stream=stream0)
        del arg16_1
        del arg19_1
        del buf11
        del buf3
        del buf7
        del buf9
        buf15 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_2, pow_1, variance, add_1, rsqrt, mul_6, hidden, hidden_states, query_states], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf14, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg20_1, (1024, 2048), (1, 1024), 0), out=buf15)
        del arg20_1
        buf16 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf14, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg21_1, (1024, 256), (1, 1024), 0), out=buf16)
        del arg21_1
        buf17 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf14, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg22_1, (1024, 256), (1, 1024), 0), out=buf17)
        del arg22_1
        buf18 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states, view_1, query_states_1, q, chunk, key_states, view_2, key_states_1, k, chunk_1, position_ids, cos_2, mul_8, neg, cat_3, sin_2, mul_9, q_embed, query_states_2, query_states_3, mul_10, neg_1, cat_4, mul_11, k_embed, key_states_2, key_states_3, value_states, view_3, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf15, arg17_1, arg18_1, buf18, 45056, stream=stream0)
        del buf15
        buf19 = empty_strided_cuda((2, 2, 11, 128), (2816, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states, view_1, query_states_1, q, chunk, key_states, view_2, key_states_1, k, chunk_1, position_ids, cos_2, mul_8, neg, cat_3, sin_2, mul_9, q_embed, query_states_2, query_states_3, mul_10, neg_1, cat_4, mul_11, k_embed, key_states_2, key_states_3, value_states, view_3, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf16, arg17_1, arg18_1, buf19, 5632, stream=stream0)
        buf20 = reinterpret_tensor(buf16, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf16  # reuse
        # Topologically Sorted Source Nodes: [query_states, view_1, query_states_1, q, chunk, key_states, view_2, key_states_1, k, chunk_1, position_ids, cos_2, mul_8, neg, cat_3, sin_2, mul_9, q_embed, query_states_2, query_states_3, mul_10, neg_1, cat_4, mul_11, k_embed, key_states_2, key_states_3, value_states, view_3, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf17, buf20, 5632, stream=stream0)
        # Topologically Sorted Source Nodes: [query_states, view_1, query_states_1, q, chunk, key_states, view_2, key_states_1, k, chunk_1, position_ids, cos_2, mul_8, neg, cat_3, sin_2, mul_9, q_embed, query_states_2, query_states_3, mul_10, neg_1, cat_4, mul_11, k_embed, key_states_2, key_states_3, value_states, view_3, value_states_1, value_states_2, attn_output], Original ATen: [aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.arange, aten.index, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf21 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf18, buf19, buf20, scale=0.08838834764831843)
        buf22 = buf21[0]
        assert_size_stride(buf22, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf22, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf21
        buf27 = reinterpret_tensor(buf18, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf18  # reuse
        # Topologically Sorted Source Nodes: [transpose_5, attn_output_1], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf22, buf27, 45056, stream=stream0)
        del buf22
        buf28 = reinterpret_tensor(buf14, (22, 1024), (1024, 1), 0); del buf14  # reuse
        # Topologically Sorted Source Nodes: [transpose_5, attn_output_1, attn_output_2, attn_output_3], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf27, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg23_1, (2048, 1024), (1, 2048), 0), out=buf28)
        del arg23_1
        del buf27
        buf30 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, to_8, pow_2, variance_1, add_5, rsqrt_1, mul_12, hidden_1, hidden_states_2], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf12, buf28, arg24_1, buf30, 22, 1024, stream=stream0)
        del arg24_1
        buf31 = empty_strided_cuda((22, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, to_8, pow_2, variance_1, add_5, rsqrt_1, mul_12, hidden_1, hidden_states_2, linear_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf30, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg25_1, (1024, 4096), (1, 1024), 0), out=buf31)
        del arg25_1
        buf32 = empty_strided_cuda((22, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_11], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf30, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg26_1, (1024, 4096), (1, 1024), 0), out=buf32)
        del arg26_1
        buf33 = reinterpret_tensor(buf31, (2, 11, 4096), (45056, 4096, 1), 0); del buf31  # reuse
        # Topologically Sorted Source Nodes: [linear_10, silu_2, linear_11, mul_14], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf33, buf32, 90112, stream=stream0)
        del buf32
        buf34 = reinterpret_tensor(buf30, (22, 1024), (1024, 1), 0); del buf30  # reuse
        # Topologically Sorted Source Nodes: [linear_10, silu_2, linear_11, mul_14, hidden_states_3], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf33, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg27_1, (4096, 1024), (1, 4096), 0), out=buf34)
        del arg27_1
        buf36 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, to_10, pow_3, variance_2, add_7, rsqrt_2, mul_15, hidden_2, hidden_states_5], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf12, buf28, buf34, arg28_1, buf36, 22, 1024, stream=stream0)
        del arg28_1
        buf37 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_4], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf36, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg29_1, (1024, 2048), (1, 1024), 0), out=buf37)
        del arg29_1
        buf38 = reinterpret_tensor(buf20, (22, 256), (256, 1), 0); del buf20  # reuse
        # Topologically Sorted Source Nodes: [key_states_4], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf36, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg30_1, (1024, 256), (1, 1024), 0), out=buf38)
        del arg30_1
        buf39 = reinterpret_tensor(buf19, (22, 256), (256, 1), 0); del buf19  # reuse
        # Topologically Sorted Source Nodes: [value_states_3], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf36, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg31_1, (1024, 256), (1, 1024), 0), out=buf39)
        del arg31_1
        buf40 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_4, view_4, query_states_5, q_1, chunk_2, key_states_4, view_5, key_states_5, k_1, chunk_3, mul_17, neg_2, cat_5, mul_18, q_embed_1, query_states_6, query_states_7, mul_19, neg_3, cat_6, mul_20, k_embed_1, key_states_6, key_states_7, value_states_3, view_6, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf37, arg17_1, arg18_1, buf40, 45056, stream=stream0)
        del buf37
        buf41 = reinterpret_tensor(buf17, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf17  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_4, view_4, query_states_5, q_1, chunk_2, key_states_4, view_5, key_states_5, k_1, chunk_3, mul_17, neg_2, cat_5, mul_18, q_embed_1, query_states_6, query_states_7, mul_19, neg_3, cat_6, mul_20, k_embed_1, key_states_6, key_states_7, value_states_3, view_6, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf38, arg17_1, arg18_1, buf41, 5632, stream=stream0)
        buf42 = reinterpret_tensor(buf38, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf38  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_4, view_4, query_states_5, q_1, chunk_2, key_states_4, view_5, key_states_5, k_1, chunk_3, mul_17, neg_2, cat_5, mul_18, q_embed_1, query_states_6, query_states_7, mul_19, neg_3, cat_6, mul_20, k_embed_1, key_states_6, key_states_7, value_states_3, view_6, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf39, buf42, 5632, stream=stream0)
        del buf39
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_4, view_4, query_states_5, q_1, chunk_2, key_states_4, view_5, key_states_5, k_1, chunk_3, mul_17, neg_2, cat_5, mul_18, q_embed_1, query_states_6, query_states_7, mul_19, neg_3, cat_6, mul_20, k_embed_1, key_states_6, key_states_7, value_states_3, view_6, value_states_4, value_states_5, attn_output_4], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf43 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf40, buf41, buf42, scale=0.08838834764831843)
        del buf41
        del buf42
        buf44 = buf43[0]
        assert_size_stride(buf44, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf44, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf43
        buf49 = reinterpret_tensor(buf40, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf40  # reuse
        # Topologically Sorted Source Nodes: [transpose_9, attn_output_5], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf44, buf49, 45056, stream=stream0)
        del buf44
        buf50 = reinterpret_tensor(buf36, (22, 1024), (1024, 1), 0); del buf36  # reuse
        # Topologically Sorted Source Nodes: [transpose_9, attn_output_5, attn_output_6, attn_output_7], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf49, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg32_1, (2048, 1024), (1, 2048), 0), out=buf50)
        del arg32_1
        del buf49
        buf52 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, attn_output_7, hidden_states_6, to_16, pow_4, variance_3, add_11, rsqrt_3, mul_21, hidden_3, hidden_states_7], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf12, buf28, buf34, buf50, arg33_1, buf52, 22, 1024, stream=stream0)
        del arg33_1
        buf53 = reinterpret_tensor(buf33, (22, 4096), (4096, 1), 0); del buf33  # reuse
        # Topologically Sorted Source Nodes: [linear_17], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf52, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg34_1, (1024, 4096), (1, 1024), 0), out=buf53)
        del arg34_1
        buf54 = empty_strided_cuda((22, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_18], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf52, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg35_1, (1024, 4096), (1, 1024), 0), out=buf54)
        del arg35_1
        del buf52
        buf55 = reinterpret_tensor(buf53, (2, 11, 4096), (45056, 4096, 1), 0); del buf53  # reuse
        # Topologically Sorted Source Nodes: [linear_17, silu_3, linear_18, mul_23], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf55, buf54, 90112, stream=stream0)
        buf56 = empty_strided_cuda((22, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_17, silu_3, linear_18, mul_23, hidden_states_8], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf55, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg36_1, (4096, 1024), (1, 4096), 0), out=buf56)
        del arg36_1
        buf57 = buf12; del buf12  # reuse
        buf59 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_3, hidden_states_1, hidden_states_3, hidden_states_4, attn_output_7, hidden_states_6, hidden_states_8, hidden_states_9, to_18, pow_5, variance_4, add_13, rsqrt_4, mul_24, hidden_4, hidden_states_10], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf57, buf28, buf34, buf50, buf56, arg37_1, buf59, 22, 1024, stream=stream0)
        del arg37_1
        del buf28
        del buf34
        del buf50
        buf60 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_18, pow_5, variance_4, add_13, rsqrt_4, mul_24, hidden_4, hidden_states_10, query_states_8], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf59, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg38_1, (1024, 2048), (1, 1024), 0), out=buf60)
        del arg38_1
        buf61 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_8], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf59, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg39_1, (1024, 256), (1, 1024), 0), out=buf61)
        del arg39_1
        buf62 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_6], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf59, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg40_1, (1024, 256), (1, 1024), 0), out=buf62)
        del arg40_1
        buf63 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_8, view_7, query_states_9, q_2, chunk_4, key_states_8, view_8, key_states_9, k_2, chunk_5, mul_26, neg_4, cat_7, mul_27, q_embed_2, query_states_10, query_states_11, mul_28, neg_5, cat_8, mul_29, k_embed_2, key_states_10, key_states_11, value_states_6, view_9, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf60, arg17_1, arg18_1, buf63, 45056, stream=stream0)
        del buf60
        buf64 = empty_strided_cuda((2, 2, 11, 128), (2816, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_8, view_7, query_states_9, q_2, chunk_4, key_states_8, view_8, key_states_9, k_2, chunk_5, mul_26, neg_4, cat_7, mul_27, q_embed_2, query_states_10, query_states_11, mul_28, neg_5, cat_8, mul_29, k_embed_2, key_states_10, key_states_11, value_states_6, view_9, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf61, arg17_1, arg18_1, buf64, 5632, stream=stream0)
        buf65 = reinterpret_tensor(buf61, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf61  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_8, view_7, query_states_9, q_2, chunk_4, key_states_8, view_8, key_states_9, k_2, chunk_5, mul_26, neg_4, cat_7, mul_27, q_embed_2, query_states_10, query_states_11, mul_28, neg_5, cat_8, mul_29, k_embed_2, key_states_10, key_states_11, value_states_6, view_9, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf62, buf65, 5632, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_8, view_7, query_states_9, q_2, chunk_4, key_states_8, view_8, key_states_9, k_2, chunk_5, mul_26, neg_4, cat_7, mul_27, q_embed_2, query_states_10, query_states_11, mul_28, neg_5, cat_8, mul_29, k_embed_2, key_states_10, key_states_11, value_states_6, view_9, value_states_7, value_states_8, attn_output_8], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf66 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf63, buf64, buf65, scale=0.08838834764831843)
        buf67 = buf66[0]
        assert_size_stride(buf67, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf67, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf66
        buf72 = reinterpret_tensor(buf63, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf63  # reuse
        # Topologically Sorted Source Nodes: [transpose_13, attn_output_9], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf67, buf72, 45056, stream=stream0)
        del buf67
        buf73 = reinterpret_tensor(buf59, (22, 1024), (1024, 1), 0); del buf59  # reuse
        # Topologically Sorted Source Nodes: [transpose_13, attn_output_9, attn_output_10, attn_output_11], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf72, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg41_1, (2048, 1024), (1, 2048), 0), out=buf73)
        del arg41_1
        del buf72
        buf75 = reinterpret_tensor(buf56, (2, 11, 1024), (11264, 1024, 1), 0); del buf56  # reuse
        # Topologically Sorted Source Nodes: [attn_output_11, hidden_states_11, to_24, pow_6, variance_5, add_17, rsqrt_5, mul_30, hidden_5, hidden_states_12], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf57, buf73, arg42_1, buf75, 22, 1024, stream=stream0)
        del arg42_1
        buf76 = reinterpret_tensor(buf55, (22, 4096), (4096, 1), 0); del buf55  # reuse
        # Topologically Sorted Source Nodes: [attn_output_11, hidden_states_11, to_24, pow_6, variance_5, add_17, rsqrt_5, mul_30, hidden_5, hidden_states_12, linear_24], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf75, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg43_1, (1024, 4096), (1, 1024), 0), out=buf76)
        del arg43_1
        buf77 = buf54; del buf54  # reuse
        # Topologically Sorted Source Nodes: [linear_25], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf75, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg44_1, (1024, 4096), (1, 1024), 0), out=buf77)
        del arg44_1
        buf78 = reinterpret_tensor(buf76, (2, 11, 4096), (45056, 4096, 1), 0); del buf76  # reuse
        # Topologically Sorted Source Nodes: [linear_24, silu_4, linear_25, mul_32], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf78, buf77, 90112, stream=stream0)
        del buf77
        buf79 = reinterpret_tensor(buf75, (22, 1024), (1024, 1), 0); del buf75  # reuse
        # Topologically Sorted Source Nodes: [linear_24, silu_4, linear_25, mul_32, hidden_states_13], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf78, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg45_1, (4096, 1024), (1, 4096), 0), out=buf79)
        del arg45_1
        buf81 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, to_26, pow_7, variance_6, add_19, rsqrt_6, mul_33, hidden_6, hidden_states_15], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf57, buf73, buf79, arg46_1, buf81, 22, 1024, stream=stream0)
        del arg46_1
        buf82 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf81, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg47_1, (1024, 2048), (1, 1024), 0), out=buf82)
        del arg47_1
        buf83 = reinterpret_tensor(buf65, (22, 256), (256, 1), 0); del buf65  # reuse
        # Topologically Sorted Source Nodes: [key_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf81, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg48_1, (1024, 256), (1, 1024), 0), out=buf83)
        del arg48_1
        buf84 = reinterpret_tensor(buf64, (22, 256), (256, 1), 0); del buf64  # reuse
        # Topologically Sorted Source Nodes: [value_states_9], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf81, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg49_1, (1024, 256), (1, 1024), 0), out=buf84)
        del arg49_1
        buf85 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_12, view_10, query_states_13, q_3, chunk_6, key_states_12, view_11, key_states_13, k_3, chunk_7, mul_35, neg_6, cat_9, mul_36, q_embed_3, query_states_14, query_states_15, mul_37, neg_7, cat_10, mul_38, k_embed_3, key_states_14, key_states_15, value_states_9, view_12, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf82, arg17_1, arg18_1, buf85, 45056, stream=stream0)
        del buf82
        buf86 = reinterpret_tensor(buf62, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf62  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_12, view_10, query_states_13, q_3, chunk_6, key_states_12, view_11, key_states_13, k_3, chunk_7, mul_35, neg_6, cat_9, mul_36, q_embed_3, query_states_14, query_states_15, mul_37, neg_7, cat_10, mul_38, k_embed_3, key_states_14, key_states_15, value_states_9, view_12, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf83, arg17_1, arg18_1, buf86, 5632, stream=stream0)
        buf87 = reinterpret_tensor(buf83, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf83  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_12, view_10, query_states_13, q_3, chunk_6, key_states_12, view_11, key_states_13, k_3, chunk_7, mul_35, neg_6, cat_9, mul_36, q_embed_3, query_states_14, query_states_15, mul_37, neg_7, cat_10, mul_38, k_embed_3, key_states_14, key_states_15, value_states_9, view_12, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf84, buf87, 5632, stream=stream0)
        del buf84
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_12, view_10, query_states_13, q_3, chunk_6, key_states_12, view_11, key_states_13, k_3, chunk_7, mul_35, neg_6, cat_9, mul_36, q_embed_3, query_states_14, query_states_15, mul_37, neg_7, cat_10, mul_38, k_embed_3, key_states_14, key_states_15, value_states_9, view_12, value_states_10, value_states_11, attn_output_12], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf88 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf85, buf86, buf87, scale=0.08838834764831843)
        del buf86
        del buf87
        buf89 = buf88[0]
        assert_size_stride(buf89, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf89, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf88
        buf94 = reinterpret_tensor(buf85, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf85  # reuse
        # Topologically Sorted Source Nodes: [transpose_17, attn_output_13], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf89, buf94, 45056, stream=stream0)
        del buf89
        buf95 = reinterpret_tensor(buf81, (22, 1024), (1024, 1), 0); del buf81  # reuse
        # Topologically Sorted Source Nodes: [transpose_17, attn_output_13, attn_output_14, attn_output_15], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf94, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg50_1, (2048, 1024), (1, 2048), 0), out=buf95)
        del arg50_1
        del buf94
        buf97 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, attn_output_15, hidden_states_16, to_32, pow_8, variance_7, add_23, rsqrt_7, mul_39, hidden_7, hidden_states_17], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf57, buf73, buf79, buf95, arg51_1, buf97, 22, 1024, stream=stream0)
        del arg51_1
        buf98 = reinterpret_tensor(buf78, (22, 4096), (4096, 1), 0); del buf78  # reuse
        # Topologically Sorted Source Nodes: [linear_31], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf97, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg52_1, (1024, 4096), (1, 1024), 0), out=buf98)
        del arg52_1
        buf99 = empty_strided_cuda((22, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_32], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf97, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg53_1, (1024, 4096), (1, 1024), 0), out=buf99)
        del arg53_1
        del buf97
        buf100 = reinterpret_tensor(buf98, (2, 11, 4096), (45056, 4096, 1), 0); del buf98  # reuse
        # Topologically Sorted Source Nodes: [linear_31, silu_5, linear_32, mul_41], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf100, buf99, 90112, stream=stream0)
        buf101 = empty_strided_cuda((22, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_31, silu_5, linear_32, mul_41, hidden_states_18], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf100, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg54_1, (4096, 1024), (1, 4096), 0), out=buf101)
        del arg54_1
        buf102 = buf57; del buf57  # reuse
        buf104 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_11, hidden_states_11, hidden_states_13, hidden_states_14, attn_output_15, hidden_states_16, hidden_states_18, hidden_states_19, to_34, pow_9, variance_8, add_25, rsqrt_8, mul_42, hidden_8, hidden_states_20], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf102, buf73, buf79, buf95, buf101, arg55_1, buf104, 22, 1024, stream=stream0)
        del arg55_1
        del buf101
        del buf73
        del buf79
        buf105 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_34, pow_9, variance_8, add_25, rsqrt_8, mul_42, hidden_8, hidden_states_20, query_states_16], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf104, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg56_1, (1024, 2048), (1, 1024), 0), out=buf105)
        del arg56_1
        buf106 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_16], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf104, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg57_1, (1024, 256), (1, 1024), 0), out=buf106)
        del arg57_1
        buf107 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_12], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf104, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg58_1, (1024, 256), (1, 1024), 0), out=buf107)
        del arg58_1
        buf108 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_16, view_13, query_states_17, q_4, chunk_8, key_states_16, view_14, key_states_17, k_4, chunk_9, mul_44, neg_8, cat_11, mul_45, q_embed_4, query_states_18, query_states_19, mul_46, neg_9, cat_12, mul_47, k_embed_4, key_states_18, key_states_19, value_states_12, view_15, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf105, arg17_1, arg18_1, buf108, 45056, stream=stream0)
        del buf105
        buf109 = empty_strided_cuda((2, 2, 11, 128), (2816, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_16, view_13, query_states_17, q_4, chunk_8, key_states_16, view_14, key_states_17, k_4, chunk_9, mul_44, neg_8, cat_11, mul_45, q_embed_4, query_states_18, query_states_19, mul_46, neg_9, cat_12, mul_47, k_embed_4, key_states_18, key_states_19, value_states_12, view_15, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf106, arg17_1, arg18_1, buf109, 5632, stream=stream0)
        buf110 = reinterpret_tensor(buf106, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf106  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_16, view_13, query_states_17, q_4, chunk_8, key_states_16, view_14, key_states_17, k_4, chunk_9, mul_44, neg_8, cat_11, mul_45, q_embed_4, query_states_18, query_states_19, mul_46, neg_9, cat_12, mul_47, k_embed_4, key_states_18, key_states_19, value_states_12, view_15, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf107, buf110, 5632, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_16, view_13, query_states_17, q_4, chunk_8, key_states_16, view_14, key_states_17, k_4, chunk_9, mul_44, neg_8, cat_11, mul_45, q_embed_4, query_states_18, query_states_19, mul_46, neg_9, cat_12, mul_47, k_embed_4, key_states_18, key_states_19, value_states_12, view_15, value_states_13, value_states_14, attn_output_16], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf111 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf108, buf109, buf110, scale=0.08838834764831843)
        buf112 = buf111[0]
        assert_size_stride(buf112, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf112, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf111
        buf117 = reinterpret_tensor(buf108, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf108  # reuse
        # Topologically Sorted Source Nodes: [transpose_21, attn_output_17], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf112, buf117, 45056, stream=stream0)
        del buf112
        buf118 = reinterpret_tensor(buf104, (22, 1024), (1024, 1), 0); del buf104  # reuse
        # Topologically Sorted Source Nodes: [transpose_21, attn_output_17, attn_output_18, attn_output_19], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf117, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg59_1, (2048, 1024), (1, 2048), 0), out=buf118)
        del arg59_1
        del buf117
        buf120 = reinterpret_tensor(buf95, (2, 11, 1024), (11264, 1024, 1), 0); del buf95  # reuse
        # Topologically Sorted Source Nodes: [attn_output_19, hidden_states_21, to_40, pow_10, variance_9, add_29, rsqrt_9, mul_48, hidden_9, hidden_states_22], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf102, buf118, arg60_1, buf120, 22, 1024, stream=stream0)
        del arg60_1
        buf121 = reinterpret_tensor(buf100, (22, 4096), (4096, 1), 0); del buf100  # reuse
        # Topologically Sorted Source Nodes: [attn_output_19, hidden_states_21, to_40, pow_10, variance_9, add_29, rsqrt_9, mul_48, hidden_9, hidden_states_22, linear_38], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf120, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg61_1, (1024, 4096), (1, 1024), 0), out=buf121)
        del arg61_1
        buf122 = buf99; del buf99  # reuse
        # Topologically Sorted Source Nodes: [linear_39], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf120, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg62_1, (1024, 4096), (1, 1024), 0), out=buf122)
        del arg62_1
        buf123 = reinterpret_tensor(buf121, (2, 11, 4096), (45056, 4096, 1), 0); del buf121  # reuse
        # Topologically Sorted Source Nodes: [linear_38, silu_6, linear_39, mul_50], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf123, buf122, 90112, stream=stream0)
        del buf122
        buf124 = reinterpret_tensor(buf120, (22, 1024), (1024, 1), 0); del buf120  # reuse
        # Topologically Sorted Source Nodes: [linear_38, silu_6, linear_39, mul_50, hidden_states_23], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf123, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg63_1, (4096, 1024), (1, 4096), 0), out=buf124)
        del arg63_1
        buf126 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_19, hidden_states_21, hidden_states_23, hidden_states_24, to_42, pow_11, variance_10, add_31, rsqrt_10, mul_51, hidden_10, hidden_states_25], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf102, buf118, buf124, arg64_1, buf126, 22, 1024, stream=stream0)
        del arg64_1
        buf127 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf126, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg65_1, (1024, 2048), (1, 1024), 0), out=buf127)
        del arg65_1
        buf128 = reinterpret_tensor(buf110, (22, 256), (256, 1), 0); del buf110  # reuse
        # Topologically Sorted Source Nodes: [key_states_20], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf126, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg66_1, (1024, 256), (1, 1024), 0), out=buf128)
        del arg66_1
        buf129 = reinterpret_tensor(buf109, (22, 256), (256, 1), 0); del buf109  # reuse
        # Topologically Sorted Source Nodes: [value_states_15], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf126, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg67_1, (1024, 256), (1, 1024), 0), out=buf129)
        del arg67_1
        buf130 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_20, view_16, query_states_21, q_5, chunk_10, key_states_20, view_17, key_states_21, k_5, chunk_11, mul_53, neg_10, cat_13, mul_54, q_embed_5, query_states_22, query_states_23, mul_55, neg_11, cat_14, mul_56, k_embed_5, key_states_22, key_states_23, value_states_15, view_18, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf127, arg17_1, arg18_1, buf130, 45056, stream=stream0)
        del buf127
        buf131 = reinterpret_tensor(buf107, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf107  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_20, view_16, query_states_21, q_5, chunk_10, key_states_20, view_17, key_states_21, k_5, chunk_11, mul_53, neg_10, cat_13, mul_54, q_embed_5, query_states_22, query_states_23, mul_55, neg_11, cat_14, mul_56, k_embed_5, key_states_22, key_states_23, value_states_15, view_18, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf128, arg17_1, arg18_1, buf131, 5632, stream=stream0)
        buf132 = reinterpret_tensor(buf128, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf128  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_20, view_16, query_states_21, q_5, chunk_10, key_states_20, view_17, key_states_21, k_5, chunk_11, mul_53, neg_10, cat_13, mul_54, q_embed_5, query_states_22, query_states_23, mul_55, neg_11, cat_14, mul_56, k_embed_5, key_states_22, key_states_23, value_states_15, view_18, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf129, buf132, 5632, stream=stream0)
        del buf129
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_20, view_16, query_states_21, q_5, chunk_10, key_states_20, view_17, key_states_21, k_5, chunk_11, mul_53, neg_10, cat_13, mul_54, q_embed_5, query_states_22, query_states_23, mul_55, neg_11, cat_14, mul_56, k_embed_5, key_states_22, key_states_23, value_states_15, view_18, value_states_16, value_states_17, attn_output_20], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf133 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf130, buf131, buf132, scale=0.08838834764831843)
        del buf131
        del buf132
        buf134 = buf133[0]
        assert_size_stride(buf134, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf134, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf133
        buf139 = reinterpret_tensor(buf130, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf130  # reuse
        # Topologically Sorted Source Nodes: [transpose_25, attn_output_21], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf134, buf139, 45056, stream=stream0)
        del buf134
        buf140 = reinterpret_tensor(buf126, (22, 1024), (1024, 1), 0); del buf126  # reuse
        # Topologically Sorted Source Nodes: [transpose_25, attn_output_21, attn_output_22, attn_output_23], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf139, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg68_1, (2048, 1024), (1, 2048), 0), out=buf140)
        del arg68_1
        del buf139
        buf142 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_19, hidden_states_21, hidden_states_23, hidden_states_24, attn_output_23, hidden_states_26, to_48, pow_12, variance_11, add_35, rsqrt_11, mul_57, hidden_11, hidden_states_27], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf102, buf118, buf124, buf140, arg69_1, buf142, 22, 1024, stream=stream0)
        del arg69_1
        buf143 = reinterpret_tensor(buf123, (22, 4096), (4096, 1), 0); del buf123  # reuse
        # Topologically Sorted Source Nodes: [linear_45], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf142, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg70_1, (1024, 4096), (1, 1024), 0), out=buf143)
        del arg70_1
        buf144 = empty_strided_cuda((22, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_46], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf142, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg71_1, (1024, 4096), (1, 1024), 0), out=buf144)
        del arg71_1
        del buf142
        buf145 = reinterpret_tensor(buf143, (2, 11, 4096), (45056, 4096, 1), 0); del buf143  # reuse
        # Topologically Sorted Source Nodes: [linear_45, silu_7, linear_46, mul_59], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf145, buf144, 90112, stream=stream0)
        buf146 = empty_strided_cuda((22, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_45, silu_7, linear_46, mul_59, hidden_states_28], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf145, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg72_1, (4096, 1024), (1, 4096), 0), out=buf146)
        del arg72_1
        buf147 = buf102; del buf102  # reuse
        buf149 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_19, hidden_states_21, hidden_states_23, hidden_states_24, attn_output_23, hidden_states_26, hidden_states_28, hidden_states_29, to_50, pow_13, variance_12, add_37, rsqrt_12, mul_60, hidden_12, hidden_states_30], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf147, buf118, buf124, buf140, buf146, arg73_1, buf149, 22, 1024, stream=stream0)
        del arg73_1
        del buf118
        del buf124
        del buf140
        buf150 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_50, pow_13, variance_12, add_37, rsqrt_12, mul_60, hidden_12, hidden_states_30, query_states_24], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf149, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg74_1, (1024, 2048), (1, 1024), 0), out=buf150)
        del arg74_1
        buf151 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf149, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg75_1, (1024, 256), (1, 1024), 0), out=buf151)
        del arg75_1
        buf152 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_18], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf149, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg76_1, (1024, 256), (1, 1024), 0), out=buf152)
        del arg76_1
        buf153 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_24, view_19, query_states_25, q_6, chunk_12, key_states_24, view_20, key_states_25, k_6, chunk_13, mul_62, neg_12, cat_15, mul_63, q_embed_6, query_states_26, query_states_27, mul_64, neg_13, cat_16, mul_65, k_embed_6, key_states_26, key_states_27, value_states_18, view_21, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf150, arg17_1, arg18_1, buf153, 45056, stream=stream0)
        del buf150
        buf154 = empty_strided_cuda((2, 2, 11, 128), (2816, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_24, view_19, query_states_25, q_6, chunk_12, key_states_24, view_20, key_states_25, k_6, chunk_13, mul_62, neg_12, cat_15, mul_63, q_embed_6, query_states_26, query_states_27, mul_64, neg_13, cat_16, mul_65, k_embed_6, key_states_26, key_states_27, value_states_18, view_21, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf151, arg17_1, arg18_1, buf154, 5632, stream=stream0)
        buf155 = reinterpret_tensor(buf151, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf151  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_24, view_19, query_states_25, q_6, chunk_12, key_states_24, view_20, key_states_25, k_6, chunk_13, mul_62, neg_12, cat_15, mul_63, q_embed_6, query_states_26, query_states_27, mul_64, neg_13, cat_16, mul_65, k_embed_6, key_states_26, key_states_27, value_states_18, view_21, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf152, buf155, 5632, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_24, view_19, query_states_25, q_6, chunk_12, key_states_24, view_20, key_states_25, k_6, chunk_13, mul_62, neg_12, cat_15, mul_63, q_embed_6, query_states_26, query_states_27, mul_64, neg_13, cat_16, mul_65, k_embed_6, key_states_26, key_states_27, value_states_18, view_21, value_states_19, value_states_20, attn_output_24], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf156 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf153, buf154, buf155, scale=0.08838834764831843)
        buf157 = buf156[0]
        assert_size_stride(buf157, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf157, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf156
        buf162 = reinterpret_tensor(buf153, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf153  # reuse
        # Topologically Sorted Source Nodes: [transpose_29, attn_output_25], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf157, buf162, 45056, stream=stream0)
        del buf157
        buf163 = reinterpret_tensor(buf149, (22, 1024), (1024, 1), 0); del buf149  # reuse
        # Topologically Sorted Source Nodes: [transpose_29, attn_output_25, attn_output_26, attn_output_27], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf162, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg77_1, (2048, 1024), (1, 2048), 0), out=buf163)
        del arg77_1
        del buf162
        buf165 = reinterpret_tensor(buf146, (2, 11, 1024), (11264, 1024, 1), 0); del buf146  # reuse
        # Topologically Sorted Source Nodes: [attn_output_27, hidden_states_31, to_56, pow_14, variance_13, add_41, rsqrt_13, mul_66, hidden_13, hidden_states_32], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf147, buf163, arg78_1, buf165, 22, 1024, stream=stream0)
        del arg78_1
        buf166 = reinterpret_tensor(buf145, (22, 4096), (4096, 1), 0); del buf145  # reuse
        # Topologically Sorted Source Nodes: [attn_output_27, hidden_states_31, to_56, pow_14, variance_13, add_41, rsqrt_13, mul_66, hidden_13, hidden_states_32, linear_52], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf165, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg79_1, (1024, 4096), (1, 1024), 0), out=buf166)
        del arg79_1
        buf167 = buf144; del buf144  # reuse
        # Topologically Sorted Source Nodes: [linear_53], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf165, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg80_1, (1024, 4096), (1, 1024), 0), out=buf167)
        del arg80_1
        buf168 = reinterpret_tensor(buf166, (2, 11, 4096), (45056, 4096, 1), 0); del buf166  # reuse
        # Topologically Sorted Source Nodes: [linear_52, silu_8, linear_53, mul_68], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf168, buf167, 90112, stream=stream0)
        del buf167
        buf169 = reinterpret_tensor(buf165, (22, 1024), (1024, 1), 0); del buf165  # reuse
        # Topologically Sorted Source Nodes: [linear_52, silu_8, linear_53, mul_68, hidden_states_33], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf168, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg81_1, (4096, 1024), (1, 4096), 0), out=buf169)
        del arg81_1
        buf171 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_27, hidden_states_31, hidden_states_33, hidden_states_34, to_58, pow_15, variance_14, add_43, rsqrt_14, mul_69, hidden_14, hidden_states_35], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf147, buf163, buf169, arg82_1, buf171, 22, 1024, stream=stream0)
        del arg82_1
        buf172 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_28], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf171, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg83_1, (1024, 2048), (1, 1024), 0), out=buf172)
        del arg83_1
        buf173 = reinterpret_tensor(buf155, (22, 256), (256, 1), 0); del buf155  # reuse
        # Topologically Sorted Source Nodes: [key_states_28], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf171, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg84_1, (1024, 256), (1, 1024), 0), out=buf173)
        del arg84_1
        buf174 = reinterpret_tensor(buf154, (22, 256), (256, 1), 0); del buf154  # reuse
        # Topologically Sorted Source Nodes: [value_states_21], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf171, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg85_1, (1024, 256), (1, 1024), 0), out=buf174)
        del arg85_1
        buf175 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_28, view_22, query_states_29, q_7, chunk_14, key_states_28, view_23, key_states_29, k_7, chunk_15, mul_71, neg_14, cat_17, mul_72, q_embed_7, query_states_30, query_states_31, mul_73, neg_15, cat_18, mul_74, k_embed_7, key_states_30, key_states_31, value_states_21, view_24, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf172, arg17_1, arg18_1, buf175, 45056, stream=stream0)
        del buf172
        buf176 = reinterpret_tensor(buf152, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf152  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_28, view_22, query_states_29, q_7, chunk_14, key_states_28, view_23, key_states_29, k_7, chunk_15, mul_71, neg_14, cat_17, mul_72, q_embed_7, query_states_30, query_states_31, mul_73, neg_15, cat_18, mul_74, k_embed_7, key_states_30, key_states_31, value_states_21, view_24, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf173, arg17_1, arg18_1, buf176, 5632, stream=stream0)
        buf177 = reinterpret_tensor(buf173, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf173  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_28, view_22, query_states_29, q_7, chunk_14, key_states_28, view_23, key_states_29, k_7, chunk_15, mul_71, neg_14, cat_17, mul_72, q_embed_7, query_states_30, query_states_31, mul_73, neg_15, cat_18, mul_74, k_embed_7, key_states_30, key_states_31, value_states_21, view_24, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf174, buf177, 5632, stream=stream0)
        del buf174
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_28, view_22, query_states_29, q_7, chunk_14, key_states_28, view_23, key_states_29, k_7, chunk_15, mul_71, neg_14, cat_17, mul_72, q_embed_7, query_states_30, query_states_31, mul_73, neg_15, cat_18, mul_74, k_embed_7, key_states_30, key_states_31, value_states_21, view_24, value_states_22, value_states_23, attn_output_28], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf178 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf175, buf176, buf177, scale=0.08838834764831843)
        del buf176
        del buf177
        buf179 = buf178[0]
        assert_size_stride(buf179, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf179, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf178
        buf184 = reinterpret_tensor(buf175, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf175  # reuse
        # Topologically Sorted Source Nodes: [transpose_33, attn_output_29], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf179, buf184, 45056, stream=stream0)
        del buf179
        buf185 = reinterpret_tensor(buf171, (22, 1024), (1024, 1), 0); del buf171  # reuse
        # Topologically Sorted Source Nodes: [transpose_33, attn_output_29, attn_output_30, attn_output_31], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf184, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg86_1, (2048, 1024), (1, 2048), 0), out=buf185)
        del arg86_1
        del buf184
        buf187 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_27, hidden_states_31, hidden_states_33, hidden_states_34, attn_output_31, hidden_states_36, to_64, pow_16, variance_15, add_47, rsqrt_15, mul_75, hidden_15, hidden_states_37], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf147, buf163, buf169, buf185, arg87_1, buf187, 22, 1024, stream=stream0)
        del arg87_1
        buf188 = reinterpret_tensor(buf168, (22, 4096), (4096, 1), 0); del buf168  # reuse
        # Topologically Sorted Source Nodes: [linear_59], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf187, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg88_1, (1024, 4096), (1, 1024), 0), out=buf188)
        del arg88_1
        buf189 = empty_strided_cuda((22, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_60], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf187, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg89_1, (1024, 4096), (1, 1024), 0), out=buf189)
        del arg89_1
        del buf187
        buf190 = reinterpret_tensor(buf188, (2, 11, 4096), (45056, 4096, 1), 0); del buf188  # reuse
        # Topologically Sorted Source Nodes: [linear_59, silu_9, linear_60, mul_77], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf190, buf189, 90112, stream=stream0)
        buf191 = empty_strided_cuda((22, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_59, silu_9, linear_60, mul_77, hidden_states_38], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf190, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg90_1, (4096, 1024), (1, 4096), 0), out=buf191)
        del arg90_1
        buf192 = buf147; del buf147  # reuse
        buf194 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_27, hidden_states_31, hidden_states_33, hidden_states_34, attn_output_31, hidden_states_36, hidden_states_38, hidden_states_39, to_66, pow_17, variance_16, add_49, rsqrt_16, mul_78, hidden_16, hidden_states_40], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf192, buf163, buf169, buf185, buf191, arg91_1, buf194, 22, 1024, stream=stream0)
        del arg91_1
        del buf163
        del buf169
        del buf185
        buf195 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_66, pow_17, variance_16, add_49, rsqrt_16, mul_78, hidden_16, hidden_states_40, query_states_32], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf194, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg92_1, (1024, 2048), (1, 1024), 0), out=buf195)
        del arg92_1
        buf196 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_32], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf194, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg93_1, (1024, 256), (1, 1024), 0), out=buf196)
        del arg93_1
        buf197 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_24], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf194, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg94_1, (1024, 256), (1, 1024), 0), out=buf197)
        del arg94_1
        buf198 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_32, view_25, query_states_33, q_8, chunk_16, key_states_32, view_26, key_states_33, k_8, chunk_17, mul_80, neg_16, cat_19, mul_81, q_embed_8, query_states_34, query_states_35, mul_82, neg_17, cat_20, mul_83, k_embed_8, key_states_34, key_states_35, value_states_24, view_27, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf195, arg17_1, arg18_1, buf198, 45056, stream=stream0)
        del buf195
        buf199 = empty_strided_cuda((2, 2, 11, 128), (2816, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_32, view_25, query_states_33, q_8, chunk_16, key_states_32, view_26, key_states_33, k_8, chunk_17, mul_80, neg_16, cat_19, mul_81, q_embed_8, query_states_34, query_states_35, mul_82, neg_17, cat_20, mul_83, k_embed_8, key_states_34, key_states_35, value_states_24, view_27, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf196, arg17_1, arg18_1, buf199, 5632, stream=stream0)
        buf200 = reinterpret_tensor(buf196, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf196  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_32, view_25, query_states_33, q_8, chunk_16, key_states_32, view_26, key_states_33, k_8, chunk_17, mul_80, neg_16, cat_19, mul_81, q_embed_8, query_states_34, query_states_35, mul_82, neg_17, cat_20, mul_83, k_embed_8, key_states_34, key_states_35, value_states_24, view_27, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf197, buf200, 5632, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_32, view_25, query_states_33, q_8, chunk_16, key_states_32, view_26, key_states_33, k_8, chunk_17, mul_80, neg_16, cat_19, mul_81, q_embed_8, query_states_34, query_states_35, mul_82, neg_17, cat_20, mul_83, k_embed_8, key_states_34, key_states_35, value_states_24, view_27, value_states_25, value_states_26, attn_output_32], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf201 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf198, buf199, buf200, scale=0.08838834764831843)
        buf202 = buf201[0]
        assert_size_stride(buf202, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf202, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf201
        buf207 = reinterpret_tensor(buf198, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf198  # reuse
        # Topologically Sorted Source Nodes: [transpose_37, attn_output_33], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf202, buf207, 45056, stream=stream0)
        del buf202
        buf208 = reinterpret_tensor(buf194, (22, 1024), (1024, 1), 0); del buf194  # reuse
        # Topologically Sorted Source Nodes: [transpose_37, attn_output_33, attn_output_34, attn_output_35], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf207, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg95_1, (2048, 1024), (1, 2048), 0), out=buf208)
        del arg95_1
        del buf207
        buf210 = reinterpret_tensor(buf191, (2, 11, 1024), (11264, 1024, 1), 0); del buf191  # reuse
        # Topologically Sorted Source Nodes: [attn_output_35, hidden_states_41, to_72, pow_18, variance_17, add_53, rsqrt_17, mul_84, hidden_17, hidden_states_42], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf192, buf208, arg96_1, buf210, 22, 1024, stream=stream0)
        del arg96_1
        buf211 = reinterpret_tensor(buf190, (22, 4096), (4096, 1), 0); del buf190  # reuse
        # Topologically Sorted Source Nodes: [attn_output_35, hidden_states_41, to_72, pow_18, variance_17, add_53, rsqrt_17, mul_84, hidden_17, hidden_states_42, linear_66], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf210, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg97_1, (1024, 4096), (1, 1024), 0), out=buf211)
        del arg97_1
        buf212 = buf189; del buf189  # reuse
        # Topologically Sorted Source Nodes: [linear_67], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf210, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg98_1, (1024, 4096), (1, 1024), 0), out=buf212)
        del arg98_1
        buf213 = reinterpret_tensor(buf211, (2, 11, 4096), (45056, 4096, 1), 0); del buf211  # reuse
        # Topologically Sorted Source Nodes: [linear_66, silu_10, linear_67, mul_86], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf213, buf212, 90112, stream=stream0)
        del buf212
        buf214 = reinterpret_tensor(buf210, (22, 1024), (1024, 1), 0); del buf210  # reuse
        # Topologically Sorted Source Nodes: [linear_66, silu_10, linear_67, mul_86, hidden_states_43], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf213, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg99_1, (4096, 1024), (1, 4096), 0), out=buf214)
        del arg99_1
        buf216 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_35, hidden_states_41, hidden_states_43, hidden_states_44, to_74, pow_19, variance_18, add_55, rsqrt_18, mul_87, hidden_18, hidden_states_45], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf192, buf208, buf214, arg100_1, buf216, 22, 1024, stream=stream0)
        del arg100_1
        buf217 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_36], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf216, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg101_1, (1024, 2048), (1, 1024), 0), out=buf217)
        del arg101_1
        buf218 = reinterpret_tensor(buf200, (22, 256), (256, 1), 0); del buf200  # reuse
        # Topologically Sorted Source Nodes: [key_states_36], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf216, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg102_1, (1024, 256), (1, 1024), 0), out=buf218)
        del arg102_1
        buf219 = reinterpret_tensor(buf199, (22, 256), (256, 1), 0); del buf199  # reuse
        # Topologically Sorted Source Nodes: [value_states_27], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf216, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg103_1, (1024, 256), (1, 1024), 0), out=buf219)
        del arg103_1
        buf220 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_36, view_28, query_states_37, q_9, chunk_18, key_states_36, view_29, key_states_37, k_9, chunk_19, mul_89, neg_18, cat_21, mul_90, q_embed_9, query_states_38, query_states_39, mul_91, neg_19, cat_22, mul_92, k_embed_9, key_states_38, key_states_39, value_states_27, view_30, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf217, arg17_1, arg18_1, buf220, 45056, stream=stream0)
        del buf217
        buf221 = reinterpret_tensor(buf197, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf197  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_36, view_28, query_states_37, q_9, chunk_18, key_states_36, view_29, key_states_37, k_9, chunk_19, mul_89, neg_18, cat_21, mul_90, q_embed_9, query_states_38, query_states_39, mul_91, neg_19, cat_22, mul_92, k_embed_9, key_states_38, key_states_39, value_states_27, view_30, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf218, arg17_1, arg18_1, buf221, 5632, stream=stream0)
        buf222 = reinterpret_tensor(buf218, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf218  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_36, view_28, query_states_37, q_9, chunk_18, key_states_36, view_29, key_states_37, k_9, chunk_19, mul_89, neg_18, cat_21, mul_90, q_embed_9, query_states_38, query_states_39, mul_91, neg_19, cat_22, mul_92, k_embed_9, key_states_38, key_states_39, value_states_27, view_30, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf219, buf222, 5632, stream=stream0)
        del buf219
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_36, view_28, query_states_37, q_9, chunk_18, key_states_36, view_29, key_states_37, k_9, chunk_19, mul_89, neg_18, cat_21, mul_90, q_embed_9, query_states_38, query_states_39, mul_91, neg_19, cat_22, mul_92, k_embed_9, key_states_38, key_states_39, value_states_27, view_30, value_states_28, value_states_29, attn_output_36], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf223 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf220, buf221, buf222, scale=0.08838834764831843)
        del buf221
        del buf222
        buf224 = buf223[0]
        assert_size_stride(buf224, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf224, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf223
        buf229 = reinterpret_tensor(buf220, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf220  # reuse
        # Topologically Sorted Source Nodes: [transpose_41, attn_output_37], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf224, buf229, 45056, stream=stream0)
        del buf224
        buf230 = reinterpret_tensor(buf216, (22, 1024), (1024, 1), 0); del buf216  # reuse
        # Topologically Sorted Source Nodes: [transpose_41, attn_output_37, attn_output_38, attn_output_39], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf229, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg104_1, (2048, 1024), (1, 2048), 0), out=buf230)
        del arg104_1
        del buf229
        buf232 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_35, hidden_states_41, hidden_states_43, hidden_states_44, attn_output_39, hidden_states_46, to_80, pow_20, variance_19, add_59, rsqrt_19, mul_93, hidden_19, hidden_states_47], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf192, buf208, buf214, buf230, arg105_1, buf232, 22, 1024, stream=stream0)
        del arg105_1
        buf233 = reinterpret_tensor(buf213, (22, 4096), (4096, 1), 0); del buf213  # reuse
        # Topologically Sorted Source Nodes: [linear_73], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf232, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg106_1, (1024, 4096), (1, 1024), 0), out=buf233)
        del arg106_1
        buf234 = empty_strided_cuda((22, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_74], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf232, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg107_1, (1024, 4096), (1, 1024), 0), out=buf234)
        del arg107_1
        del buf232
        buf235 = reinterpret_tensor(buf233, (2, 11, 4096), (45056, 4096, 1), 0); del buf233  # reuse
        # Topologically Sorted Source Nodes: [linear_73, silu_11, linear_74, mul_95], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf235, buf234, 90112, stream=stream0)
        buf236 = empty_strided_cuda((22, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_73, silu_11, linear_74, mul_95, hidden_states_48], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf235, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg108_1, (4096, 1024), (1, 4096), 0), out=buf236)
        del arg108_1
        buf237 = buf192; del buf192  # reuse
        buf239 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_35, hidden_states_41, hidden_states_43, hidden_states_44, attn_output_39, hidden_states_46, hidden_states_48, hidden_states_49, to_82, pow_21, variance_20, add_61, rsqrt_20, mul_96, hidden_20, hidden_states_50], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_12.run(buf237, buf208, buf214, buf230, buf236, arg109_1, buf239, 22, 1024, stream=stream0)
        del arg109_1
        del buf208
        del buf214
        del buf230
        buf240 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_82, pow_21, variance_20, add_61, rsqrt_20, mul_96, hidden_20, hidden_states_50, query_states_40], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf239, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg110_1, (1024, 2048), (1, 1024), 0), out=buf240)
        del arg110_1
        buf241 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [key_states_40], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf239, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg111_1, (1024, 256), (1, 1024), 0), out=buf241)
        del arg111_1
        buf242 = empty_strided_cuda((22, 256), (256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [value_states_30], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf239, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg112_1, (1024, 256), (1, 1024), 0), out=buf242)
        del arg112_1
        buf243 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_40, view_31, query_states_41, q_10, chunk_20, key_states_40, view_32, key_states_41, k_10, chunk_21, mul_98, neg_20, cat_23, mul_99, q_embed_10, query_states_42, query_states_43, mul_100, neg_21, cat_24, mul_101, k_embed_10, key_states_42, key_states_43, value_states_30, view_33, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf240, arg17_1, arg18_1, buf243, 45056, stream=stream0)
        del buf240
        buf244 = empty_strided_cuda((2, 2, 11, 128), (2816, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_40, view_31, query_states_41, q_10, chunk_20, key_states_40, view_32, key_states_41, k_10, chunk_21, mul_98, neg_20, cat_23, mul_99, q_embed_10, query_states_42, query_states_43, mul_100, neg_21, cat_24, mul_101, k_embed_10, key_states_42, key_states_43, value_states_30, view_33, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf241, arg17_1, arg18_1, buf244, 5632, stream=stream0)
        buf245 = reinterpret_tensor(buf241, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf241  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_40, view_31, query_states_41, q_10, chunk_20, key_states_40, view_32, key_states_41, k_10, chunk_21, mul_98, neg_20, cat_23, mul_99, q_embed_10, query_states_42, query_states_43, mul_100, neg_21, cat_24, mul_101, k_embed_10, key_states_42, key_states_43, value_states_30, view_33, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf242, buf245, 5632, stream=stream0)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_40, view_31, query_states_41, q_10, chunk_20, key_states_40, view_32, key_states_41, k_10, chunk_21, mul_98, neg_20, cat_23, mul_99, q_embed_10, query_states_42, query_states_43, mul_100, neg_21, cat_24, mul_101, k_embed_10, key_states_42, key_states_43, value_states_30, view_33, value_states_31, value_states_32, attn_output_40], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf246 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf243, buf244, buf245, scale=0.08838834764831843)
        buf247 = buf246[0]
        assert_size_stride(buf247, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf247, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf246
        buf252 = reinterpret_tensor(buf243, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf243  # reuse
        # Topologically Sorted Source Nodes: [transpose_45, attn_output_41], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf247, buf252, 45056, stream=stream0)
        del buf247
        buf253 = reinterpret_tensor(buf239, (22, 1024), (1024, 1), 0); del buf239  # reuse
        # Topologically Sorted Source Nodes: [transpose_45, attn_output_41, attn_output_42, attn_output_43], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf252, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg113_1, (2048, 1024), (1, 2048), 0), out=buf253)
        del arg113_1
        del buf252
        buf255 = reinterpret_tensor(buf236, (2, 11, 1024), (11264, 1024, 1), 0); del buf236  # reuse
        # Topologically Sorted Source Nodes: [attn_output_43, hidden_states_51, to_88, pow_22, variance_21, add_65, rsqrt_21, mul_102, hidden_21, hidden_states_52], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_8.run(buf237, buf253, arg114_1, buf255, 22, 1024, stream=stream0)
        del arg114_1
        buf256 = reinterpret_tensor(buf235, (22, 4096), (4096, 1), 0); del buf235  # reuse
        # Topologically Sorted Source Nodes: [attn_output_43, hidden_states_51, to_88, pow_22, variance_21, add_65, rsqrt_21, mul_102, hidden_21, hidden_states_52, linear_80], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf255, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg115_1, (1024, 4096), (1, 1024), 0), out=buf256)
        del arg115_1
        buf257 = buf234; del buf234  # reuse
        # Topologically Sorted Source Nodes: [linear_81], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf255, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg116_1, (1024, 4096), (1, 1024), 0), out=buf257)
        del arg116_1
        buf258 = reinterpret_tensor(buf256, (2, 11, 4096), (45056, 4096, 1), 0); del buf256  # reuse
        # Topologically Sorted Source Nodes: [linear_80, silu_12, linear_81, mul_104], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf258, buf257, 90112, stream=stream0)
        del buf257
        buf259 = reinterpret_tensor(buf255, (22, 1024), (1024, 1), 0); del buf255  # reuse
        # Topologically Sorted Source Nodes: [linear_80, silu_12, linear_81, mul_104, hidden_states_53], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf258, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg117_1, (4096, 1024), (1, 4096), 0), out=buf259)
        del arg117_1
        buf261 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, to_90, pow_23, variance_22, add_67, rsqrt_22, mul_105, hidden_22, hidden_states_55], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_10.run(buf237, buf253, buf259, arg118_1, buf261, 22, 1024, stream=stream0)
        del arg118_1
        buf262 = empty_strided_cuda((22, 2048), (2048, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [query_states_44], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf261, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg119_1, (1024, 2048), (1, 1024), 0), out=buf262)
        del arg119_1
        buf263 = reinterpret_tensor(buf245, (22, 256), (256, 1), 0); del buf245  # reuse
        # Topologically Sorted Source Nodes: [key_states_44], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf261, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg120_1, (1024, 256), (1, 1024), 0), out=buf263)
        del arg120_1
        buf264 = reinterpret_tensor(buf244, (22, 256), (256, 1), 0); del buf244  # reuse
        # Topologically Sorted Source Nodes: [value_states_33], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf261, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg121_1, (1024, 256), (1, 1024), 0), out=buf264)
        del arg121_1
        buf265 = empty_strided_cuda((2, 16, 11, 128), (22528, 1408, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_44, view_34, query_states_45, q_11, chunk_22, key_states_44, view_35, key_states_45, k_11, chunk_23, mul_107, neg_22, cat_25, mul_108, q_embed_11, query_states_46, query_states_47, mul_109, neg_23, cat_26, mul_110, k_embed_11, key_states_46, key_states_47, value_states_33, view_36, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_4.run(buf262, arg17_1, arg18_1, buf265, 45056, stream=stream0)
        del buf262
        buf266 = reinterpret_tensor(buf242, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf242  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_44, view_34, query_states_45, q_11, chunk_22, key_states_44, view_35, key_states_45, k_11, chunk_23, mul_107, neg_22, cat_25, mul_108, q_embed_11, query_states_46, query_states_47, mul_109, neg_23, cat_26, mul_110, k_embed_11, key_states_46, key_states_47, value_states_33, view_36, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_5.run(buf263, arg17_1, arg18_1, buf266, 5632, stream=stream0)
        del arg17_1
        del arg18_1
        buf267 = reinterpret_tensor(buf263, (2, 2, 11, 128), (2816, 1408, 128, 1), 0); del buf263  # reuse
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_44, view_34, query_states_45, q_11, chunk_22, key_states_44, view_35, key_states_45, k_11, chunk_23, mul_107, neg_22, cat_25, mul_108, q_embed_11, query_states_46, query_states_47, mul_109, neg_23, cat_26, mul_110, k_embed_11, key_states_46, key_states_47, value_states_33, view_36, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        stream0 = get_raw_stream(0)
        triton_poi_fused__scaled_dot_product_flash_attention__to_copy__unsafe_view_add_arange_cat_clone_index_mul_neg_split_transpose_view_6.run(buf264, buf267, 5632, stream=stream0)
        del buf264
        # Topologically Sorted Source Nodes: [position_ids, cos_2, sin_2, query_states_44, view_34, query_states_45, q_11, chunk_22, key_states_44, view_35, key_states_45, k_11, chunk_23, mul_107, neg_22, cat_25, mul_108, q_embed_11, query_states_46, query_states_47, mul_109, neg_23, cat_26, mul_110, k_embed_11, key_states_46, key_states_47, value_states_33, view_36, value_states_34, value_states_35, attn_output_44], Original ATen: [aten.arange, aten.index, aten._unsafe_view, aten.view, aten.transpose, aten._to_copy, aten.split, aten.mul, aten.neg, aten.cat, aten.add, aten.clone, aten._scaled_dot_product_flash_attention]
        buf268 = torch.ops.aten._scaled_dot_product_flash_attention.default(buf265, buf266, buf267, scale=0.08838834764831843)
        del buf266
        del buf267
        buf269 = buf268[0]
        assert_size_stride(buf269, (2, 16, 11, 128), (22528, 1408, 128, 1), 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        assert_alignment(buf269, 16, 'torch.ops.aten._scaled_dot_product_flash_attention.default')
        del buf268
        buf274 = reinterpret_tensor(buf265, (2, 11, 16, 128), (22528, 2048, 128, 1), 0); del buf265  # reuse
        # Topologically Sorted Source Nodes: [transpose_49, attn_output_45], Original ATen: [aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused_clone_transpose_7.run(buf269, buf274, 45056, stream=stream0)
        del buf269
        buf275 = reinterpret_tensor(buf261, (22, 1024), (1024, 1), 0); del buf261  # reuse
        # Topologically Sorted Source Nodes: [transpose_49, attn_output_45, attn_output_46, attn_output_47], Original ATen: [aten.transpose, aten.clone, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf274, (22, 2048), (2048, 1), 0), reinterpret_tensor(arg122_1, (2048, 1024), (1, 2048), 0), out=buf275)
        del arg122_1
        del buf274
        buf277 = empty_strided_cuda((2, 11, 1024), (11264, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, attn_output_47, hidden_states_56, to_96, pow_24, variance_23, add_71, rsqrt_23, mul_111, hidden_23, hidden_states_57], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean, aten.rsqrt, aten.mul]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_mul_pow_rsqrt_11.run(buf237, buf253, buf259, buf275, arg123_1, buf277, 22, 1024, stream=stream0)
        del arg123_1
        buf278 = reinterpret_tensor(buf258, (22, 4096), (4096, 1), 0); del buf258  # reuse
        # Topologically Sorted Source Nodes: [linear_87], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf277, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg124_1, (1024, 4096), (1, 1024), 0), out=buf278)
        del arg124_1
        buf279 = empty_strided_cuda((22, 4096), (4096, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_88], Original ATen: [aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf277, (22, 1024), (1024, 1), 0), reinterpret_tensor(arg125_1, (1024, 4096), (1, 1024), 0), out=buf279)
        del arg125_1
        del buf277
        buf280 = reinterpret_tensor(buf278, (2, 11, 4096), (45056, 4096, 1), 0); del buf278  # reuse
        # Topologically Sorted Source Nodes: [linear_87, silu_13, linear_88, mul_113], Original ATen: [aten._unsafe_view, aten.silu, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_mul_silu_9.run(buf280, buf279, 90112, stream=stream0)
        del buf279
        buf281 = empty_strided_cuda((22, 1024), (1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [linear_87, silu_13, linear_88, mul_113, hidden_states_58], Original ATen: [aten._unsafe_view, aten.silu, aten.mul, aten.view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf280, (22, 4096), (4096, 1), 0), reinterpret_tensor(arg126_1, (4096, 1024), (1, 4096), 0), out=buf281)
        del arg126_1
        del buf280
        buf282 = buf237; del buf237  # reuse
        buf283 = empty_strided_cuda((2, 11, 1), (11, 1, 22), torch.float32)
        # Topologically Sorted Source Nodes: [attn_output_43, hidden_states_51, hidden_states_53, hidden_states_54, attn_output_47, hidden_states_56, hidden_states_58, hidden_states_59, to_98, pow_25, variance_24], Original ATen: [aten._unsafe_view, aten.add, aten._to_copy, aten.pow, aten.mean]
        stream0 = get_raw_stream(0)
        triton_per_fused__to_copy__unsafe_view_add_mean_pow_13.run(buf282, buf253, buf259, buf275, buf281, buf283, 22, 1024, stream=stream0)
        del buf253
        del buf259
        del buf275
        del buf281
        buf284 = empty_strided_cuda((2, 4, 1024), (4096, 1024, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_98, pow_25, variance_24, add_73, rsqrt_24, mul_114, hidden_24, hidden_states_60, hidden_25, hidden_26], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.slice, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_add_clone_mean_mul_pow_rsqrt_slice_14.run(buf282, buf283, arg127_1, buf284, 8192, stream=stream0)
        del arg127_1
        del buf282
        del buf283
        buf285 = empty_strided_cuda((8, 64), (64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [to_98, pow_25, variance_24, add_73, rsqrt_24, mul_114, hidden_24, hidden_states_60, hidden_25, hidden_26], Original ATen: [aten._to_copy, aten.pow, aten.mean, aten.add, aten.rsqrt, aten.mul, aten.slice, aten.clone, aten._unsafe_view, aten.t, aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf284, (8, 1024), (1024, 1), 0), reinterpret_tensor(arg128_1, (1024, 64), (1, 1024), 0), out=buf285)
        del arg128_1
        del buf284
        buf286 = empty_strided_cuda((2, 64, 4), (256, 4, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [hidden_26, transpose_50, contiguous_50], Original ATen: [aten._unsafe_view, aten.add, aten.transpose, aten.clone]
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_add_clone_transpose_15.run(buf285, arg129_1, buf286, 128, 4, stream=stream0)
        del arg129_1
        del buf285
    return (buf286, )


# Wait on the module-level AsyncCompile instance before anything below runs.
# NOTE(review): presumably this blocks until the kernels compiled earlier in
# this module are ready and binds them into the globals dict passed in --
# confirm against torch._inductor.async_compile.
async_compile.wait(globals())
# The AsyncCompile handle is not used after this point, so drop the name.
del async_compile

class Runner:
    """Route the flat list of graph inputs to the compiled partition callables.

    ``partitions`` is a list of callables; this wrapper only ever invokes
    ``partitions[0]``. Each partition is called with a single list argument
    and must return a 1-tuple.
    """

    # For every slot of the list handed to partition 0, the index of the
    # value to take from the flat ``args`` list (arg<i>_1 lives at index i).
    # The first 23 inputs are permuted; indices 23..129 pass through in order.
    _PARTITION0_ORDER = (
        6, 8, 7, 10, 9, 11, 13, 12, 15, 14,
        3, 5, 4, 0, 2, 1,
        16, 19, 20, 21, 22, 17, 18,
        *range(23, 130),
    )

    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        """Replace each partition by ``fn(partition)``, pairing ``fns`` positionally.

        Pairing stops at the shorter of ``fns`` and the current partition list.
        """
        self.partitions = [
            wrap(partition) for wrap, partition in zip(fns, self.partitions)
        ]

    def call(self, args):
        """Run partition 0 on the 130 graph inputs and return its 1-tuple result.

        ``args`` must hold exactly 130 values (otherwise ``ValueError`` is
        raised and ``args`` is left untouched). On success ``args`` is
        emptied in place, so the caller's list no longer keeps the inputs
        alive.
        """
        expected = len(self._PARTITION0_ORDER)
        inputs = list(args)
        if len(inputs) != expected:
            raise ValueError(
                f"expected {expected} graph inputs, got {len(inputs)}"
            )
        args.clear()
        partition0_args = [inputs[index] for index in self._PARTITION0_ORDER]
        # Drop the temporary copy so the partition's argument list is the only
        # reference this frame still holds to the inputs.
        del inputs
        (buf286,) = self.partitions[0](partition0_args)
        del partition0_args
        return (buf286, )

# Wrap the single compiled partition in a Runner and expose its bound methods
# as the module-level entry points `call` and `recursively_apply_fns`.
runner = Runner(partitions=[partition_0,])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns


def get_args():
    """Create one ``rand_strided`` tensor per graph input, ordered arg0_1 .. arg129_1.

    Every tensor is allocated on ``cuda:0`` with dtype ``torch.bfloat16``.
    The tensors are created in argument order and returned as a list of
    130 entries.
    """
    from torch._dynamo.testing import rand_strided

    # (size, stride) pairs that occur among the inputs.
    spec_2 = ((2, ), (1, ))
    spec_64 = ((64, ), (1, ))
    spec_1024 = ((1024, ), (1, ))
    spec_2x64x4 = ((2, 64, 4), (256, 4, 1))
    spec_2x2048 = ((2, 2048), (2048, 1))
    spec_64x1024 = ((64, 1024), (1024, 1))
    spec_256x1024 = ((256, 1024), (1024, 1))
    spec_1024x64 = ((1024, 64), (64, 1))
    spec_1024x1024 = ((1024, 1024), (1024, 1))
    spec_1024x2048 = ((1024, 2048), (2048, 1))
    spec_1024x4096 = ((1024, 4096), (4096, 1))
    spec_2048x1024 = ((2048, 1024), (1024, 1))
    spec_4096x1024 = ((4096, 1024), (1024, 1))
    spec_32768x128 = ((32768, 128), (128, 1))

    specs = [
        # arg0_1 .. arg5_1
        spec_2x64x4, spec_1024x64, spec_1024,
        spec_2x64x4, spec_1024x64, spec_1024,
        # arg6_1 .. arg10_1
        spec_2, spec_1024x1024, spec_1024, spec_1024x1024, spec_1024,
        # arg11_1 .. arg15_1
        spec_2, spec_1024x1024, spec_1024, spec_1024x1024, spec_1024,
        # arg16_1 .. arg18_1
        spec_2x2048, spec_32768x128, spec_32768x128,
    ]

    # arg19_1 .. arg126_1: the same group of nine specs repeated twelve times.
    repeated_group = [
        spec_1024,
        spec_2048x1024,
        spec_256x1024,
        spec_256x1024,
        spec_1024x2048,
        spec_1024,
        spec_4096x1024,
        spec_4096x1024,
        spec_1024x4096,
    ]
    specs.extend(repeated_group * 12)

    # arg127_1 .. arg129_1
    specs.extend([spec_1024, spec_64x1024, spec_64])

    return [
        rand_strided(size, stride, device='cuda:0', dtype=torch.bfloat16)
        for size, stride in specs
    ]


def benchmark_compiled_module(args, times=10, repeat=10):
    """Benchmark the module-level ``call`` entry point on ``args``.

    ``times`` and ``repeat`` are forwarded unchanged to Inductor's
    ``print_performance``, whose return value is returned.
    """
    from torch._inductor.utils import print_performance

    def run_once():
        # ``call`` empties the list it is given, so hand it a fresh copy on
        # every invocation and keep the caller's ``args`` intact.
        return call(list(args))

    return print_performance(run_once, times=times, repeat=repeat)


if __name__ == "__main__":
    # Script entry point: build the benchmark inputs once and hand a timing
    # callback to Inductor's compiled_module_main driver.
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    # NOTE(review): the first argument is the literal string 'None' (not the
    # None object); presumably a benchmark-name placeholder -- confirm.
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))
