
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
# NOTE(review): module-level side effect executed at import time; presumably
# selects the GPU backend as Triton's active driver before the kernel below is
# defined — confirm against torch._inductor.runtime.triton_helpers.
triton_helpers.set_driver_to_gpu()

# Fused masked ("safe") softmax, computed in place over 16 rows of 8192 fp32
# values stored contiguously in in_out_ptr0 (element (x0, r0_1) lives at
# in_out_ptr0[r0_1 + 8192*x0]).
#
# Each value first gets an additive mask: 0.0 where the column index
# r0_1 <= in_ptr0[0] (a single int64 scalar), and -inf otherwise. The row is
# then overwritten with exp(v - max) / sum along the r0 axis, where v is the
# masked value. A row in which no masked value differs from -inf is written
# as all 0.0 instead. NOTE(review): with max == -inf the normal formula would
# presumably yield NaN; that is why the fallback exists — confirm.
#
# Each program handles XBLOCK rows (1-D grid over program_id(0)) and makes
# two passes over its rows in R0_BLOCK-wide tiles:
#   pass 1: OR together per-lane "value != -inf" flags and accumulate the
#           online-softmax max/sum state, then reduce both along axis 1;
#   pass 2: recompute the masked value, normalize it, and store it in place.
#
# Emitted by TorchInductor (kernel_name/backend_hash in inductor_meta);
# presumably regenerated on recompilation, so hand edits may not persist.
@triton_heuristics.reduction(
    size_hints={'x': 16, 'r0_': 8192},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=48, cc=121, major=12, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 4, 'num_store': 1, 'num_reduction': 3, 'backend_hash': 'BE9F3F68E84A48F2C239366BB10BBF3F0E1498DBECF7F747524E7B7961E82579', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'mix_order_reduction_allow_multi_stages': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'tiling_scores': {'x': 0, 'r0_': 1572864}}
)
@triton.jit
def triton_red_fused__safe_softmax_add_arange_exp_le_prepare_softmax_online_scalar_tensor_sub_view_where_7(in_out_ptr0, in_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    # Problem size is baked in: these constants shadow the xnumel/r0_numel args.
    xnumel = 16
    r0_numel = 8192
    # Legacy aliases; not referenced below.
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    # This program owns rows [xoffset, xoffset + XBLOCK) as an [XBLOCK, 1]
    # column; xmask disables rows >= 16.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    # Column offsets within one [1, R0_BLOCK] tile.
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    # Scalar column limit in_ptr0[0] (int64), broadcast to [1, 1].
    tmp1 = tl.load(in_ptr0 + (0))
    tmp2 = tl.broadcast_to(tmp1, [1, 1])
    # Per-lane accumulator: has any masked value seen so far been != -inf?
    _tmp15 = tl.full([XBLOCK, R0_BLOCK], False, tl.int1)
    # Per-lane online-softmax state (running max starts at -inf, sum at 0).
    _tmp19_max = tl.full([XBLOCK, R0_BLOCK], float('-inf'), tl.float32)
    _tmp19_sum = tl.zeros([XBLOCK, R0_BLOCK], tl.float32)
    # Pass 1: stream each row tile by tile to build the flag and max/sum state.
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        # evict_last hint: presumably keeps the tile cached for the pass-2 reread.
        tmp0 = tl.load(in_out_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
        # Additive mask: 0.0 for columns r0_1 <= in_ptr0[0], -inf otherwise.
        tmp3 = r0_1
        tmp4 = tmp3 <= tmp2
        tmp5 = tl.full([1, 1], 0.0, tl.float32)
        tmp6 = tl.full([1, 1], float("-inf"), tl.float32)
        tmp7 = tl.where(tmp4, tmp5, tmp6)
        tmp8 = tmp7.to(tl.float32)
        # Masked value v = x + mask.
        tmp9 = tmp0 + tmp8
        # Lane flag: v != -inf. The int64 round trip (tmp12/tmp13) is
        # generated noise and yields the same boolean as tmp11.
        tmp10 = tmp9 == tmp6
        tmp11 = tmp10 == 0
        tmp12 = tmp11.to(tl.int64)
        tmp13 = (tmp12 != 0)
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
        # OR the flag into the accumulator, for in-bounds lanes only.
        tmp16 = _tmp15 | tmp14
        _tmp15 = tl.where(r0_mask & xmask, tmp16, _tmp15)
        tmp18 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])

        # NOTE(review): presumably folds v into the running (max, sum of
        # exp(v - max)) state per lane — see triton_helpers for exact semantics.
        _tmp19_max_next, _tmp19_sum_next = triton_helpers.online_softmax_combine(
            _tmp19_max, _tmp19_sum, tmp18, False
        )

        # Commit the new state only for in-bounds lanes.
        _tmp19_max = tl.where(r0_mask & xmask, _tmp19_max_next, _tmp19_max)
        _tmp19_sum = tl.where(r0_mask & xmask, _tmp19_sum_next, _tmp19_sum)
    # Collapse the lane flags along axis 1 into one value per row, shaped
    # [XBLOCK, 1]. A value of 0 means the row has no masked value != -inf.
    tmp17 = _tmp15.to(tl.int8)
    tmp15 = triton_helpers.any(tmp17, 1)[:, None]

    # Reduce the per-lane state along axis 1. Pass 2 uses tmp19 as the row max
    # and tmp20 as the softmax denominator; [:, None] keeps them [XBLOCK, 1].
    tmp19, tmp20 = triton_helpers.online_softmax_reduce(
        _tmp19_max, _tmp19_sum, 1, False)
    tmp19 = tmp19[:, None]
    tmp20 = tmp20[:, None]
    # Reload the same scalar limit in_ptr0[0] for pass 2.
    tmp23 = tl.load(in_ptr0 + (0))
    tmp24 = tl.broadcast_to(tmp23, [1, 1])
    # Pass 2: recompute the masked values, normalize them, and store in place.
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        # evict_first hint: this is the last read of the tile.
        tmp22 = tl.load(in_out_ptr0 + (r0_1 + 8192*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        # True for rows whose masked values are all -inf (loop-invariant).
        tmp21 = tmp15 == 0
        # Same additive mask as pass 1.
        tmp25 = r0_1
        tmp26 = tmp25 <= tmp24
        tmp27 = tl.full([1, 1], 0.0, tl.float32)
        tmp28 = tl.full([1, 1], float("-inf"), tl.float32)
        tmp29 = tl.where(tmp26, tmp27, tmp28)
        tmp30 = tmp29.to(tl.float32)
        tmp31 = tmp22 + tmp30
        # Softmax: exp(v - row_max) / row_denominator.
        tmp32 = tmp31 - tmp19
        tmp33 = libdevice.exp(tmp32)
        tmp34 = (tmp33 / tmp20)
        # Fully masked rows produce 0.0 instead of the normalized value.
        tmp35 = tl.where(tmp21, tmp27, tmp34)
        # In-place write back (in_out_ptr0 is listed in mutated_arg_names).
        tl.store(in_out_ptr0 + (r0_1 + 8192*x0), tmp35, r0_mask & xmask)
