#!/usr/bin/env python3
import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available
import json, subprocess, os

print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")

# Generate audio with VoxCPM2
print("\nGenerating audio with VoxCPM2...")
result = subprocess.run([
    'curl', '-X', 'POST', 'http://192.168.1.127:8101/v1/audio/speech',
    '-H', 'Content-Type: application/json',
    '-d', json.dumps({
        'model': 'openbmb/VoxCPM2',
        'input': '(voix masculine professionnelle, ton sérieux) Bonjour et bienvenue dans ce rapport de chantier. Nous sommes le 23 avril 2026 et nous intervenons sur le chantier de rénovation du centre-ville de Montpellier. L équipe au complet est présente : trois maçons, deux électriciens et un menuisier. Le chantier a démarré à 7 heures du matin et les conditions météorologiques sont excellentes, avec un ciel dégagé et une température de 22 degrés.',
        'voice': 'default',
        'cfg_value': 2.0,
        'inference_timesteps': 10
    }),
    '-o', '/tmp/rcc_test.wav'
], capture_output=True, text=True)

if result.returncode != 0 or not os.path.exists('/tmp/rcc_test.wav') or os.path.getsize('/tmp/rcc_test.wav') < 1000:
    print(f"✗ Audio gen failed: {result.stderr or result.returncode}")
    exit(1)

size = os.path.getsize('/tmp/rcc_test.wav')
print(f"✓ Audio saved: /tmp/rcc_test.wav ({size/1024:.0f} KB)")

# Transcribe with Whisper directly using transformers (no pyannote)
print("\nTranscribing with Whisper Large v3 (flash-attn)...")

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    dtype=torch.float16,
    device="cuda:0",
    model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
)

print(f"Flash Attn 2: {is_flash_attn_2_available()}")

outputs = pipe(
    "/tmp/rcc_test.wav",
    chunk_length_s=30,
    batch_size=8,
    return_timestamps=True,
)

print(f"✓ Transcript: {outputs['text'][:400]}...")
print(f"Length: {len(outputs['text'])} chars")
