
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 256}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 1536}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_cat_index_index_put_mul_neg_select_split_transpose_view_2(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Pointwise kernel over 256 elements, addressed as 2 rows x 128 columns.
    #
    # For each element it computes, in float32,
    #     in_ptr1[e] * in_ptr2[col + 128*p] + swapped[e] * in_ptr3[col + 128*p]
    # where p is the single index read from in_ptr0 and `swapped` is the same
    # in_ptr1 row with its two 64-wide halves exchanged and the half that
    # lands in columns 0..63 negated. The result is stored to
    # out_ptr0[col + 128*p + 1048576*row].
    #
    # NOTE(review): presumably a rotate-half (rotary-embedding style) update
    # written into one slot of a preallocated buffer - confirm with the
    # originating graph.
    #
    # Arguments (element types per the triton_meta signature above):
    #   in_ptr0  - int64; only element 0 is read. Negative values wrap.
    #   in_ptr1  - bf16 source values, 256 contiguous elements.
    #   in_ptr2  - bf16 table, row stride 128, row selected by the index.
    #   in_ptr3  - bf16 table, row stride 128, row selected by the index.
    #   out_ptr0 - bf16 destination, written in place.
    #   xnumel   - overridden below with the specialized constant.
    #   XBLOCK   - compile-time number of elements handled per program.

    # The element count is baked in for this specialization; the runtime
    # argument is intentionally shadowed.
    xnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    col = xindex % 128
    row = xindex // 128

    # One shared index for the whole launch, broadcast across the block.
    raw_index = tl.broadcast_to(tl.load(in_ptr0 + (0)), [XBLOCK])
    is_negative = raw_index < 0

    source = tl.load(in_ptr1 + (xindex), xmask).to(tl.float32)

    # Wrap a negative index against the destination extent (8192 rows).
    # The assert text keeps the generated temporary name `tmp5`.
    out_row = tl.where(is_negative, raw_index + 8192, raw_index)
    tl.device_assert((0 <= out_row) & (out_row < 8192), "index out of bounds: 0 <= tmp5 < 8192")

    # Wrap the same index against the table extent (32768 rows).
    # The assert text keeps the generated temporary name `tmp11`.
    table_row = tl.where(is_negative, raw_index + 32768, raw_index)
    tl.device_assert((0 <= table_row) & (table_row < 32768), "index out of bounds: 0 <= tmp11 < 32768")

    # Both tables are read at the same (row, col) position.
    table_offset = col + 128*table_row
    direct_scale = tl.load(in_ptr2 + (table_offset), xmask).to(tl.float32)
    direct_term = source * direct_scale

    # Half-swap of the source row. col is always in [0, 128), so
    # `col >= 64` is the exact complement of `col < 64` and one select
    # picks between the two masked loads.
    in_low_half = col < 64
    in_high_half = col >= 64
    from_high = tl.load(in_ptr1 + (64 + 128*row + (col)), in_low_half & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    from_low = tl.load(in_ptr1 + (128*row + ((-64) + col)), in_high_half & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    swapped = tl.where(in_low_half, -from_high, from_low)

    swapped_scale = tl.load(in_ptr3 + (table_offset), xmask).to(tl.float32)
    swapped_term = swapped * swapped_scale

    result = direct_term + swapped_term
    # tl.store converts the float32 result to the bf16 element type of
    # out_ptr0.
    tl.store(out_ptr0 + (col + 128*out_row + 1048576*row), result, xmask)
