#!/usr/bin/env python3
"""
GPU-Accelerated 8D Chromatic Visualization
Uses PyOpenCL to leverage AMD RX 480 for data processing
"""

import re
import time
import numpy as np
from collections import deque

# Try to import PyOpenCL
try:
    import pyopencl as cl
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False
    print("⚠️  PyOpenCL not installed. Install with: pip install pyopencl")
    print("   Falling back to CPU mode (slower)")

import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

# Path of the log being tailed. Relative paths resolve against the current
# working directory, not against this script's location.
LOG_FILE = "../logs/peer1.log"
TRAIL_LENGTH = 5000  # Much longer trail possible with GPU!

# OpenCL kernels, compiled once when GPUChromaticMonitor initializes the GPU.
#
#   normalize_log       output = (log10(|x|) - min_val) / range, element-wise
#                       over the first n floats. Magnitudes are clamped at
#                       FLT_MIN (smallest normal float) so zeros give a finite
#                       log (about -37.9) instead of -inf. An additive epsilon
#                       such as 1e-100f cannot do this job: it is below the
#                       float range and rounds to 0.0f.
#   compute_projection  8-D -> 3-D "musical harmony" projection. It is built
#                       with the program but not launched in this script;
#                       animate() performs the same projection with NumPy.
NORMALIZE_KERNEL = """
__kernel void normalize_log(__global const float *input,
                            __global float *output,
                            const float min_val,
                            const float range,
                            const int n) {
    int gid = get_global_id(0);
    if (gid < n) {
        // Clamp at FLT_MIN so log10 never sees 0 (which would give -inf).
        float val = log10(fmax(fabs(input[gid]), FLT_MIN));
        output[gid] = (val - min_val) / range;
    }
}

__kernel void compute_projection(__global const float *d1,
                                 __global const float *d2,
                                 __global const float *d3,
                                 __global const float *d4,
                                 __global const float *d5,
                                 __global const float *d6,
                                 __global const float *d7,
                                 __global const float *d8,
                                 __global float *x_out,
                                 __global float *y_out,
                                 __global float *z_out,
                                 const int n) {
    int gid = get_global_id(0);
    if (gid < n) {
        // Musical harmony projection
        x_out[gid] = d1[gid] * 0.6f + d8[gid] * 0.4f;  // Fundamental + Octave
        y_out[gid] = d3[gid] * 0.5f + d6[gid] * 0.5f;  // Thirds + Sixths
        z_out[gid] = d5[gid] * 0.5f + d4[gid] * 0.5f;  // Fifths + Fourths
    }
}
"""

class GPUChromaticMonitor:
    """Tail the peer log, collect 8-D positions and normalize them for plotting.

    Positions are parsed from lines containing ``D1: <num>`` ... ``D8: <num>``.
    Every second complete 8-tuple is kept in a bounded trail of at most
    ``TRAIL_LENGTH`` entries.

    Normalization takes log10 of each magnitude, then min/max-scales each
    dimension (column) to [0, 1]. It runs through the OpenCL ``normalize_log``
    kernel when an AMD GPU was initialized, and through NumPy otherwise. Both
    paths apply the same scaling.
    """

    # Smallest positive normal float32. Magnitudes are clamped to it before
    # log10 so zeros map to a finite value (about -37.9) instead of -inf. An
    # additive 1e-100 epsilon does not work for float32 data: it rounds to 0.0,
    # so log10(0) = -inf and the min/max scaling then produces NaN.
    _LOG_FLOOR = np.float32(np.finfo(np.float32).tiny)

    # Matches "Evolution: <int> ... Ω: <decimal>". The decimal is matched
    # strictly, so a trailing sentence period ("Ω: 999.5.") is not captured.
    _EVOLUTION_RE = re.compile(r'Evolution: (\d+).*Ω: (\d+(?:\.\d*)?|\.\d+)')

    # One pattern per dimension, "D<k>: <float>". It accepts a sign and an
    # exponent in either direction (e.g. "1.5e-05", "-3", "2.0e+10").
    _DIM_RES = tuple(
        re.compile(r'D%d: ([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)' % k)
        for k in range(1, 9)
    )

    def __init__(self):
        """Create an empty trail and try to set up an OpenCL context on an AMD GPU.

        Any GPU setup failure is reported on stdout and leaves
        ``gpu_enabled`` False, so processing falls back to the CPU.
        """
        self.positions = deque(maxlen=TRAIL_LENGTH)  # trail of raw 8-tuples
        self.last_position = 0      # byte offset of the next unread log data
        self.current_omega = 0
        self.current_evolution = 0
        self.skip_counter = 0       # complete 8-tuples seen; every 2nd is kept

        self.gpu_enabled = False
        if not GPU_AVAILABLE:
            return
        try:
            amd_device = self._find_amd_gpu()
            if amd_device is None:
                print("⚠️  AMD GPU not found, using CPU")
                return
            self.ctx = cl.Context([amd_device])
            self.queue = cl.CommandQueue(self.ctx)
            self.program = cl.Program(self.ctx, NORMALIZE_KERNEL).build(options=['-cl-fast-relaxed-math'])
            self.gpu_enabled = True
            print(f"✅ GPU Initialized: {amd_device.name}")
            print(f"   Compute Units: {amd_device.max_compute_units}")
            print(f"   Max Clock: {amd_device.max_clock_frequency} MHz")
        except Exception as e:
            print(f"⚠️  GPU init failed: {e}")
            print("   Falling back to CPU mode")

    @staticmethod
    def _find_amd_gpu():
        """Return the first OpenCL GPU device made by AMD, or None if there is none."""
        for platform in cl.get_platforms():
            for device in platform.get_devices(device_type=cl.device_type.GPU):
                # AMD's own drivers often report a codename as the device name
                # (e.g. "Ellesmere" for an RX 480), so the vendor is checked too.
                ident = f"{device.name} {device.vendor}"
                if ('AMD' in ident or 'Radeon' in ident
                        or 'Advanced Micro Devices' in ident):
                    return device
        return None

    def update(self, log_file=None):
        """Parse log data appended since the previous call.

        The file is read from byte offset ``self.last_position`` up to its last
        newline. A line the writer has only partly flushed is therefore left
        for the next call instead of being parsed as a truncated number.

        If the file is now shorter than the saved offset (truncated or
        rotated), reading restarts from the beginning. A missing or unreadable
        file is skipped until the next call.

        Each complete set of D1-D8 values is one sample. Every second sample is
        appended to ``self.positions``. At the same time ``current_evolution``
        and ``current_omega`` are set from the latest ``Evolution: ... Ω: ...``
        line seen in this batch. They are left unchanged when no such line was
        seen, rather than being overwritten with None.

        Args:
            log_file: Path to read. Defaults to the module-level ``LOG_FILE``.
        """
        path = LOG_FILE if log_file is None else log_file
        try:
            with open(path, 'rb') as f:
                f.seek(0, 2)  # whence=2 (end of file): learn the current size
                if f.tell() < self.last_position:
                    self.last_position = 0
                f.seek(self.last_position)
                chunk = f.read()
        except OSError:
            return

        complete = chunk.rfind(b'\n') + 1
        if complete == 0:
            return  # nothing new, or only an unterminated partial line
        self.last_position += complete
        # Splitting on b'\n' is UTF-8 safe: 0x0A never occurs inside a
        # multi-byte sequence.
        text = chunk[:complete].decode('utf-8', errors='ignore')

        evolution = omega = None
        dims = {}
        for line in text.split('\n'):
            match = self._EVOLUTION_RE.search(line)
            if match:
                evolution = int(match.group(1))
                omega = float(match.group(2))

            for index, pattern in enumerate(self._DIM_RES):
                match = pattern.search(line)
                if match:
                    dims[index] = float(match.group(1))

            if len(dims) == 8:
                self.skip_counter += 1
                if self.skip_counter % 2 == 0:  # keep every 2nd sample
                    self.positions.append(tuple(dims[i] for i in range(8)))
                    if evolution is not None:
                        self.current_evolution = evolution
                    if omega is not None:
                        self.current_omega = omega
                dims = {}

    def process_gpu(self, data):
        """Normalize ``data`` using the OpenCL ``normalize_log`` kernel.

        The kernel is launched with ``min_val=0`` and ``range=1``, so it returns
        plain log10 magnitudes. The per-column [0, 1] scaling is then done on
        the host in exactly the same way as in :meth:`process_cpu`, so GPU and
        CPU modes draw the same picture.

        Falls back to :meth:`process_cpu` when the GPU is disabled, when there
        are fewer than two rows, or when any OpenCL call fails.

        Args:
            data: 2-D array-like with one row per sample and one column per
                dimension.

        Returns:
            Array of the same shape with values in [0, 1].
        """
        if not self.gpu_enabled or len(data) < 2:
            return self.process_cpu(data)

        try:
            magnitudes = np.maximum(np.abs(np.asarray(data, dtype=np.float32)),
                                    self._LOG_FLOOR)
            flat = np.ascontiguousarray(magnitudes).ravel()
            n = flat.size

            mf = cl.mem_flags
            input_buf = cl.Buffer(self.ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=flat)
            output_buf = cl.Buffer(self.ctx, mf.WRITE_ONLY, flat.nbytes)

            self.program.normalize_log(self.queue, (n,), None,
                                       input_buf, output_buf,
                                       np.float32(0.0), np.float32(1.0),
                                       np.int32(n))

            log_data = np.empty_like(flat)
            cl.enqueue_copy(self.queue, log_data, output_buf)  # blocking by default
        except Exception as e:
            print(f"GPU error, falling back to CPU: {e}")
            return self.process_cpu(data)

        return self._normalize_columns(log_data.reshape(magnitudes.shape))

    def process_cpu(self, data):
        """Normalize ``data`` with NumPy (the CPU fallback).

        Args:
            data: 2-D array-like with one row per sample and one column per
                dimension.

        Returns:
            A float64 array of the same shape. Each entry is log10 of the
            magnitude (clamped at ``_LOG_FLOOR``), min/max scaled to [0, 1]
            per column. Columns whose values are all equal map to 0.
        """
        # Compute in float64 so the clamp and log are not subject to float32 underflow.
        magnitudes = np.maximum(np.abs(np.asarray(data, dtype=np.float64)),
                                self._LOG_FLOOR)
        return self._normalize_columns(np.log10(magnitudes))

    @staticmethod
    def _normalize_columns(log_data):
        """Linearly scale each column onto [0, 1]; constant columns become 0."""
        if log_data.shape[0] == 0:
            return log_data
        lo = log_data.min(axis=0)
        span = log_data.max(axis=0) - lo
        varying = span > 0
        # Divide constant columns by 1 to avoid 0/0, then zero them out.
        scaled = (log_data - lo) / np.where(varying, span, 1)
        return np.where(varying, scaled, 0)

def get_chromatic_color(dim_index):
    """Return the hex colour string assigned to dimension ``dim_index``.

    Indices 0-7 walk the spectrum red, orange, yellow, green, cyan, azure,
    violet, magenta. Normal Python list indexing applies, so negative indices
    wrap around and out-of-range ones raise IndexError.
    """
    spectrum = ('#FF0000 #FF8800 #FFFF00 #00FF00 '
                '#00FFFF #0088FF #8800FF #FF00FF').split()
    return spectrum[dim_index]

def animate(frame, monitor, ax, scatter, lines, text_info):
    """Update visualization with GPU acceleration"""
    monitor.update()

    if len(monitor.positions) < 2:
        return [scatter] + lines + [text_info]

    # Convert to numpy array
    arr = np.array(monitor.positions, dtype=np.float32)
    n_points = len(arr)

    # GPU-accelerated normalization
    arr = monitor.process_gpu(arr)

    # Musical harmony projection (vectorized - fast on CPU)
    x = arr[:, 0] * 0.6 + arr[:, 7] * 0.4
    y = arr[:, 2] * 0.5 + arr[:, 5] * 0.5
    z = arr[:, 4] * 0.5 + arr[:, 3] * 0.5

    # Color by dominant dimension
    dim_dominance = np.argmax(arr, axis=1)
    colors = [get_chromatic_color(d) for d in dim_dominance]

    # Update main scatter
    scatter._offsets3d = (x, y, z)
    scatter.set_color(colors)

    # Draw harmonic traces (subsample for speed)
    step = max(1, n_points // 500)
    projections = [(0, 1, 2), (3, 4, 5), (6, 7, 0)]

    for idx, (i, j, k) in enumerate(projections):
        xi = arr[::step, i]
        yi = arr[::step, j]
        zi = arr[::step, k]
        lines[idx].set_data(xi, yi)
        lines[idx].set_3d_properties(zi)

    # Update info
    mode = "GPU" if monitor.gpu_enabled else "CPU"
    omega_status = "SAT" if monitor.current_omega >= 999 else "GROW"
    info_text = (f"Mode: {mode}\n"
                f"Evol: {monitor.current_evolution:,}\n"
                f"Ω: {monitor.current_omega:.1f} {omega_status}\n"
                f"Trail: {n_points}")
    text_info.set_text(info_text)

    # Dynamic rotation
    ax.view_init(elev=20 + 10*np.sin(frame/20), azim=frame % 360)

    return [scatter] + lines + [text_info]

def main():
    """Print a startup banner, build the 3-D figure and run the live animation.

    Blocks in ``plt.show()`` until the window is closed. A KeyboardInterrupt
    raised out of ``plt.show()`` is caught and reported as "Stopped.".
    """
    rule = "=" * 70
    for banner_line in (rule,
                        "  GPU-ACCELERATED 8D CHROMATIC VISUALIZATION",
                        "  Target: AMD Radeon RX 480",
                        rule):
        print(banner_line)

    monitor = GPUChromaticMonitor()

    if not monitor.gpu_enabled:
        print("\n⚠️  GPU Acceleration: DISABLED (CPU fallback)")
        print("   Install PyOpenCL: pip install pyopencl")
    else:
        print("\n✅ GPU Acceleration: ENABLED")
        print(f"   Trail length: {TRAIL_LENGTH} points (10x longer!)")
        print("   Update rate: 100ms (faster)")

    print(f"\n  Monitoring: {LOG_FILE}")
    print("  Press Ctrl+C to stop\n")

    # Figure and 3-D axes
    figure = plt.figure(figsize=(16, 12))
    axes = figure.add_subplot(111, projection='3d')

    figure.suptitle('GPU-Accelerated 8D Chromatic Attractor | AMD RX 480',
                    fontsize=14, fontweight='bold', color='cyan')

    label_style = dict(fontsize=11, color='white')
    axes.set_xlabel('Fundamental + Octave', **label_style)
    axes.set_ylabel('Harmony', **label_style)
    axes.set_zlabel('Resonance', **label_style)
    for set_limits in (axes.set_xlim, axes.set_ylim, axes.set_zlim):
        set_limits(0, 1)

    # Trail points
    point_cloud = axes.scatter([], [], [], s=5, alpha=0.7, edgecolors='none')

    # Harmonic trace lines, in the order animate() expects
    trace_specs = (('r-', 'Fund. Triad'),
                   ('c-', 'Harm. Triad'),
                   ('m-', 'Octave Bridge'))
    traces = []
    for fmt, label in trace_specs:
        trace, = axes.plot([], [], [], fmt, alpha=0.15, linewidth=0.8, label=label)
        traces.append(trace)

    # Status box in the top-left corner of the axes
    info_box = axes.text2D(0.02, 0.98, "", transform=axes.transAxes,
                           fontsize=10, verticalalignment='top', color='cyan',
                           bbox=dict(boxstyle='round', facecolor='black', alpha=0.8))

    axes.legend(loc='lower right', fontsize=9, facecolor='black', labelcolor='white')
    axes.grid(True, alpha=0.1, color='cyan')

    # Dark theme
    axes.set_facecolor('#0a0a1a')
    figure.patch.set_facecolor('#050510')
    axes.tick_params(colors='cyan')

    # 100 ms per frame with the GPU, 200 ms on the CPU fallback
    frame_ms = 100 if monitor.gpu_enabled else 200
    # The reference must be kept: an Animation that is garbage-collected stops.
    animation = FuncAnimation(figure, animate,
                              fargs=(monitor, axes, point_cloud, traces, info_box),
                              interval=frame_ms,
                              blit=False,
                              cache_frame_data=False)

    try:
        plt.show()
    except KeyboardInterrupt:
        print("\n\nStopped.")

# Launch the visualizer only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()
