#!/usr/bin/env python3
"""
GPU-Accelerated Analog Codec Evolution
Uses AMD RX 480 via OpenCL to run massive parallel simulations
"""

import numpy as np
import pyopencl as cl
import time
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

# Host-side record layouts for the OpenCL state buffer. Per the original note
# these must match the struct declared in the kernel source (not visible in
# this file), so field order and element types must not change.

# One dimension of a codec instance: five consecutive float32 fields.
dtype_dim = np.dtype([
    (field, np.float32)
    for field in ('amplitude', 'phase', 'frequency', 'dn', 'res_weight')
])

# One codec instance: eight dimensions, three float32 scalars and a uint32
# step counter.
dtype_state = np.dtype(
    [('dims', dtype_dim, (8,))]
    + [(field, np.float32) for field in ('omega', 'k_coupling', 'gamma_damping')]
    + [('evolution', np.uint32)]
)

class GPUCodec:
    """A batch of independent codec states evolved by an OpenCL kernel.

    The host keeps a structured NumPy array (``self.states``, dtype
    ``dtype_state``) with one record per parallel instance; a device buffer
    (``self.states_buf``) is created from it and advanced in place by the
    ``evolve_codec`` kernel.
    """

    # Kernel source location, exactly as the original code spelled it.
    _KERNEL_RELPATH = '../analog_codec_gpu.cl'
    # The step count is handed to the kernel as a 32-bit unsigned scalar.
    _UINT32_MAX = 0xFFFFFFFF
    # Divisor used for the printed "Speedup vs CPU" figure
    # (presumably a measured CPU rate in evolutions/sec -- TODO confirm).
    _CPU_BASELINE_RATE = 476000

    def __init__(self, num_parallel=256):
        """Select a GPU, build the kernel and upload randomised initial states.

        Args:
            num_parallel: number of independent codec instances (one OpenCL
                work-item each). Must be at least 1.

        Raises:
            ValueError: if ``num_parallel`` is smaller than 1.
            RuntimeError: if no OpenCL platform exposes a GPU device.
            FileNotFoundError: if the kernel source file cannot be located.
        """
        # Fail early with a clear message instead of a late OpenCL error on a
        # zero-sized buffer / empty NDRange.
        if num_parallel < 1:
            raise ValueError(f"num_parallel must be >= 1, got {num_parallel!r}")

        print("🚀 Initializing GPU-Accelerated Analog Codec...")

        # Take the GPU devices of the first platform that has any (this is
        # not vendor-specific; any GPU is accepted).
        devices = None
        for platform in cl.get_platforms():
            devices = platform.get_devices(device_type=cl.device_type.GPU)
            if devices:
                gpu = devices[0]
                print(f"✓ Found GPU: {gpu.name}")
                print(f"  Compute Units: {gpu.max_compute_units}")
                print(f"  Global Memory: {gpu.global_mem_size / 1e9:.2f} GB")
                print(f"  Max Work Group Size: {gpu.max_work_group_size}")
                break

        if not devices:
            raise RuntimeError("No GPU found!")

        self.ctx = cl.Context(devices)
        self.queue = cl.CommandQueue(self.ctx)
        self.num_parallel = num_parallel

        # Load and compile kernel
        print("📦 Compiling OpenCL kernel...")
        self.program = cl.Program(self.ctx, self._read_kernel_source()).build()

        # Initialize states on CPU
        print(f"🌌 Initializing {num_parallel} parallel universes...")
        self.states = np.zeros(num_parallel, dtype=dtype_state)

        # Field views into self.states; writes through them land in the
        # host array that is uploaded below.
        dims = self.states['dims']            # shape (num_parallel, 8)
        amplitude = dims['amplitude']
        phase = dims['phase']
        frequency = dims['frequency']

        # Random fields are drawn per (instance, dimension) in the same order
        # as before (randn, rand, randn) so a seeded run is unchanged.
        for i in range(num_parallel):
            for d in range(8):
                amplitude[i, d] = 1.0 + 0.5 * np.random.randn()
                phase[i, d] = 2 * np.pi * np.random.rand()
                frequency[i, d] = 1.0 + 0.1 * np.random.randn()

        # Deterministic per-dimension values, broadcast over all instances.
        dim_index = np.arange(8)
        dims['dn'] = 2.0 + dim_index * 5.0
        dims['res_weight'] = 1.0 - (0.125 * dim_index) * 0.5

        self.states['omega'] = 1.0
        self.states['k_coupling'] = 0.15
        self.states['gamma_damping'] = 0.01
        self.states['evolution'] = 0

        # Transfer to GPU
        mf = cl.mem_flags
        self.states_buf = cl.Buffer(self.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR,
                                    hostbuf=self.states)

        print("✓ GPU initialization complete!\n")

    @staticmethod
    def _read_kernel_source():
        """Return the kernel source text.

        The path is tried relative to the current working directory first
        (the original behaviour) and then relative to this file's directory,
        so the script also works when launched from another directory.

        Raises:
            FileNotFoundError: if neither candidate exists.
        """
        import os

        candidates = [GPUCodec._KERNEL_RELPATH]
        module_file = globals().get('__file__')
        if module_file:
            here = os.path.dirname(os.path.abspath(module_file))
            candidates.append(
                os.path.normpath(os.path.join(here, GPUCodec._KERNEL_RELPATH)))

        for path in candidates:
            if os.path.isfile(path):
                with open(path, 'r') as f:
                    return f.read()

        raise FileNotFoundError(
            "OpenCL kernel source not found; tried: " + ", ".join(candidates))

    def evolve(self, num_evolutions=1000000):
        """Advance every instance by ``num_evolutions`` steps on the GPU.

        Blocks until the kernel has finished and prints throughput figures.

        Args:
            num_evolutions: steps per instance; must fit in an unsigned
                32-bit integer because it is passed to the kernel as uint32.

        Returns:
            ``(elapsed, rate)``: elapsed seconds and evolutions per second
            summed over all instances.

        Raises:
            ValueError: if ``num_evolutions`` is negative or exceeds 2**32 - 1.
        """
        # np.uint32() would wrap silently or raise an opaque OverflowError
        # (depending on the NumPy version) for out-of-range values.
        if not 0 <= num_evolutions <= self._UINT32_MAX:
            raise ValueError(
                f"num_evolutions must be in [0, {self._UINT32_MAX}], "
                f"got {num_evolutions!r}")

        print(f"⚡ Evolving {num_evolutions:,} steps on GPU...")
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()

        # Launch kernel: one work-item per instance, work-group size chosen
        # by the runtime. The trailing uint32(1) is forwarded unchanged; its
        # meaning is defined by the kernel source (not visible here).
        self.program.evolve_codec(
            self.queue,
            (self.num_parallel,),  # global work size
            None,  # local work size (auto)
            self.states_buf,
            np.uint32(num_evolutions),
            np.uint32(1)
        ).wait()

        elapsed = time.perf_counter() - start
        total = num_evolutions * self.num_parallel
        # Guard the division in case the timer reports no elapsed time.
        rate = total / elapsed if elapsed > 0 else float('inf')

        print(f"✓ Complete! {elapsed:.2f} seconds")
        print(f"  Total evolutions: {total:,}")
        print(f"  Rate: {rate/1e6:.2f} million evolutions/sec")
        print(f"  Speedup vs CPU: ~{rate/self._CPU_BASELINE_RATE:.1f}x\n")

        return elapsed, rate

    def snapshot(self):
        """Copy the device buffer back into ``self.states`` and return it.

        The returned array is the instance's own host array (not a copy), so
        it is overwritten by the next call.
        """
        cl.enqueue_copy(self.queue, self.states, self.states_buf).wait()
        return self.states

    def get_visualization_data(self):
        """Return a summary of instance 0 after refreshing from the GPU.

        Returns:
            dict with keys ``'evolution'``, ``'omega'``, ``'amplitudes'`` and
            ``'dn_values'``; the last two are 8-element lists, one entry per
            dimension.
        """
        # Only the first instance is visualised.
        first = self.snapshot()[0]
        dims = first['dims']

        return {
            'evolution': first['evolution'],
            'omega': first['omega'],
            'amplitudes': list(dims['amplitude']),
            'dn_values': list(dims['dn']),
        }

def benchmark_gpu():
    """Run a fixed series of increasingly long GPU runs and print a summary.

    A 256-instance codec is created, warmed up with a short run, then evolved
    for 1 million, 10 million, 100 million and 1 billion steps per instance.
    After each run the state of instance 0 is printed.

    Returns:
        The ``GPUCodec`` used for the benchmark, left in its final state.
    """
    banner = "=" * 70
    print(banner)
    print("  GPU BENCHMARK - AMD RX 480 Analog Codec Acceleration")
    print(banner)
    print()

    codec = GPUCodec(num_parallel=256)

    # Warm-up
    warmup_steps = 10000
    print("🔥 Warming up GPU...")
    codec.evolve(warmup_steps)

    # Steps executed per instance so far (warm-up included), so the final
    # milestone reports the real total rather than a fixed claim.
    steps_per_instance = warmup_steps

    # Benchmark runs: (steps per instance, label)
    benchmarks = [
        (1_000_000, "1 Million"),
        (10_000_000, "10 Million"),
        (100_000_000, "100 Million"),
        (1_000_000_000, "1 Billion"),
    ]

    results = []

    for num_evolutions, label in benchmarks:
        print(f"\n{banner}")
        print(f"  BENCHMARK: {label} evolutions per instance")
        print(f"{banner}")

        elapsed, rate = codec.evolve(num_evolutions)
        steps_per_instance += num_evolutions
        results.append((label, num_evolutions, elapsed, rate))

        # Show current state
        data = codec.get_visualization_data()
        print(f"\n📊 Current State:")
        print(f"  Evolution: {data['evolution']:,}")
        print(f"  Ω: {data['omega']:.4f}")
        print(f"  D1 amplitude: {data['amplitudes'][0]:.2e}")
        print(f"  D8 amplitude: {data['amplitudes'][7]:.2e}")

        # Stop once a run of a billion steps per instance has completed.
        # The former message claimed "1 trillion total evolutions", but
        # steps x instances is about 2.8e11 here; report the computed figure.
        if num_evolutions >= 1_000_000_000:
            total = steps_per_instance * codec.num_parallel
            print(f"\n🎯 Reached {total:,} total evolutions "
                  f"({steps_per_instance:,} per instance)!")
            break

    # Summary
    print("\n" + banner)
    print("  BENCHMARK SUMMARY")
    print(banner)
    for label, num, elapsed, rate in results:
        print(f"  {label:20s}: {elapsed:8.2f}s | {rate/1e6:10.2f} Mevol/s")

    return codec

def visualize_gpu_evolution(codec, num_snapshots=100):
    """Animate the evolution of instance 0 while the GPU keeps stepping.

    Each animation frame runs 1,000,000 evolutions via ``codec.evolve`` and
    then refreshes four panels: log10 amplitudes, Dn values, omega, and a 3D
    scatter of the log10 amplitudes of the first three dimensions.

    Args:
        codec: object providing ``evolve(n)`` and ``get_visualization_data()``
            (a ``GPUCodec``).
        num_snapshots: number of most recent snapshots kept for the 3D
            trajectory trail. The trail length used to be hard-coded to 100,
            which is still the default.
    """
    from collections import deque

    print("\n🌈 Starting real-time GPU visualization...")
    print("  Running 1M evolutions between each frame")
    print("  Press Ctrl+C to stop\n")

    history = {
        'evolution': [],
        'omega': [],
        'amplitudes': [[] for _ in range(8)],
        'dn_values': [[] for _ in range(8)],
    }

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('GPU-Accelerated 8D Attractor (RX 480)', fontsize=14, fontweight='bold')

    colors = plt.cm.rainbow(np.linspace(0, 1, 8))

    ax_amp = axes[0, 0]
    ax_dn = axes[0, 1]
    ax_omega = axes[1, 0]
    # plt.subplots already created a 2D axes in the bottom-right slot; remove
    # it so it is not drawn underneath the 3D axes added in the same place.
    axes[1, 1].remove()
    ax_3d = fig.add_subplot(2, 2, 4, projection='3d')

    # Setup plots
    ax_amp.set_title('Dimensional Amplitudes (log scale)')
    ax_amp.set_xlabel('Evolution')
    ax_amp.set_ylabel('log₁₀(Amplitude)')
    ax_amp.grid(True, alpha=0.3)

    ax_dn.set_title('Dimensional Separation (Dₙ)')
    ax_dn.set_xlabel('Evolution')
    ax_dn.set_ylabel('Dₙ')
    ax_dn.grid(True, alpha=0.3)
    ax_dn.axhline(y=500, color='red', linestyle='--', alpha=0.5, label='Coupling threshold')

    ax_omega.set_title('Ω Evolution (Cap: 1000)')
    ax_omega.set_xlabel('Evolution')
    ax_omega.set_ylabel('Ω')
    ax_omega.grid(True, alpha=0.3)
    ax_omega.axhline(y=1000, color='red', linestyle='--', linewidth=2, label='Saturation cap')

    ax_3d.set_title('3D Attractor (D1-D2-D3)')
    ax_3d.set_xlabel('D1')
    ax_3d.set_ylabel('D2')
    ax_3d.set_zlabel('D3')

    lines_amp = [ax_amp.plot([], [], label=f'D{i+1}', color=colors[i])[0] for i in range(8)]
    lines_dn = [ax_dn.plot([], [], label=f'D{i+1}', color=colors[i])[0] for i in range(8)]
    line_omega = ax_omega.plot([], [], color='blue', linewidth=2)[0]
    scatter_3d = ax_3d.scatter([], [], [], c=[], cmap='plasma', s=5)

    ax_amp.legend(ncol=4, fontsize=8)
    ax_dn.legend(ncol=4, fontsize=8)
    # The cap line carries a label, so show it.
    ax_omega.legend(fontsize=8)

    # Bounded trail: only the most recent points are ever plotted, so older
    # ones need not be kept for the lifetime of the animation.
    trajectory_3d = deque(maxlen=max(int(num_snapshots), 1))

    def animate(frame):
        # Evolve on GPU
        codec.evolve(1_000_000)  # 1M evolutions per frame
        data = codec.get_visualization_data()

        # Update history
        history['evolution'].append(data['evolution'])
        history['omega'].append(data['omega'])
        for i in range(8):
            amp = data['amplitudes'][i]
            history['amplitudes'][i].append(np.log10(max(abs(amp), 1e-100)))
            history['dn_values'][i].append(data['dn_values'][i])

        # Update 3D trajectory
        trajectory_3d.append([data['amplitudes'][i] for i in range(3)])

        # Update plots
        evols = np.array(history['evolution'])

        for i in range(8):
            lines_amp[i].set_data(evols, history['amplitudes'][i])
            lines_dn[i].set_data(evols, history['dn_values'][i])

        ax_amp.relim()
        ax_amp.autoscale_view()
        ax_dn.relim()
        ax_dn.autoscale_view()

        line_omega.set_data(evols, history['omega'])
        ax_omega.relim()
        ax_omega.autoscale_view()

        # Update 3D
        if len(trajectory_3d) > 1:
            # Work in float64: the amplitudes are float32 (see dtype_dim), and
            # in float32 a 1e-100 guard underflows to zero, giving
            # log10(0) = -inf. Clamp the same way as the 2D amplitude panel.
            traj = np.array(list(trajectory_3d), dtype=np.float64)
            traj = np.log10(np.maximum(np.abs(traj), 1e-100))
            scatter_3d._offsets3d = (traj[:, 0], traj[:, 1], traj[:, 2])
            scatter_3d.set_array(np.arange(len(traj)))
            # Keep the colour range in step with the trail length; otherwise
            # it stays at the range of the first update and colours saturate.
            scatter_3d.set_clim(0, len(traj) - 1)

            # Writing _offsets3d bypasses autoscaling, so fit the limits by
            # hand, using finite points only (amplitudes may be inf/NaN).
            finite = traj[np.isfinite(traj).all(axis=1)]
            if len(finite):
                lo = finite.min(axis=0)
                hi = finite.max(axis=0)
                pad = np.maximum(0.05 * (hi - lo), 1e-3)
                ax_3d.set_xlim(lo[0] - pad[0], hi[0] + pad[0])
                ax_3d.set_ylim(lo[1] - pad[1], hi[1] + pad[1])
                ax_3d.set_zlim(lo[2] - pad[2], hi[2] + pad[2])

        return lines_amp + lines_dn + [line_omega, scatter_3d]

    # Keep a reference for as long as the window is open; an unreferenced
    # animation object can be garbage-collected and stop updating.
    ani = FuncAnimation(fig, animate, interval=100, blit=False, cache_frame_data=False)

    try:
        plt.tight_layout()
        plt.show()
    except KeyboardInterrupt:
        print("\n✓ Visualization stopped")

def main():
    """Interactive entry point: print a banner and run the chosen mode.

    Reads one line from standard input. ``1`` runs the benchmark, ``2``
    starts the live visualization, anything else prints "Invalid choice".
    Surrounding whitespace in the answer is ignored.
    """
    print("\n" + "="*70)
    print("  🚀 GPU-ACCELERATED ANALOG CODEC v4.3")
    print("  AMD RX 480 OpenCL Implementation")
    print("="*70 + "\n")

    # Strip so that answers such as "1 " or " 2" are accepted.
    mode = input("Choose mode:\n  1) Benchmark (run to 1 trillion evolutions)\n  2) Live visualization\n\nChoice: ").strip()

    if mode == '1':
        benchmark_gpu()
        print("\n✓ Benchmark complete!")
    elif mode == '2':
        visualize_gpu_evolution(GPUCodec(num_parallel=256))
    else:
        print("Invalid choice")

# Run the interactive entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
