#!/usr/bin/env python3
"""
Multi-Dimensional Phase Space Geometry Visualization
Shows 2D projections of the 5D attractor geometry
"""

import re
import time
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from collections import deque

# Log file tailed incrementally by PhaseSpaceMonitor. A relative path is
# resolved against the current working directory when the file is opened.
LOG_FILE = '../logs/peer1.log'
# Maximum number of snapshots kept in the rolling trail; older ones drop off.
TRAIL_LENGTH = 3_000

class PhaseSpaceMonitor:
    """Incrementally tail LOG_FILE and collect 8-dimensional position snapshots.

    Each update() call reads only the bytes appended since the previous call.
    A snapshot is complete once values for all of D1..D8 have been seen (they
    may be spread over several lines and several update() calls). It is then
    appended to ``positions``, a bounded deque of 8-tuples of floats with the
    oldest first. The most recent Evolution/Ω values seen before it are
    published as ``current_evolution`` / ``current_omega``.
    """

    # Signed decimal number with optional exponent, e.g. "12", "-3.5", "1.5e-10".
    _NUMBER = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
    # Status line carrying the evolution counter and the Ω value.
    _EVOLUTION_RE = re.compile(r'Evolution: (\d+).*Ω: (' + _NUMBER + r')')
    # One "D<k>: <value>" field with k in 1..8. The \b prevents e.g. "XD1:" from matching.
    _DIM_RE = re.compile(r'\bD([1-8]): (' + _NUMBER + r')')

    def __init__(self):
        self.positions = deque(maxlen=TRAIL_LENGTH)
        self.last_position = 0  # byte offset of the first unread byte in LOG_FILE
        self.current_omega = 0
        self.current_evolution = 0
        # Parse state carried across update() calls, so a record split between
        # two reads (or a line still being written) is not lost.
        self._partial = b''            # trailing bytes not yet terminated by '\n'
        self._pending_dims = {}        # zero-based dim index -> value
        self._pending_evolution = None
        self._pending_omega = None

    def update(self):
        """Read newly appended log lines and record any completed snapshots.

        Only newline-terminated lines are parsed. An unterminated tail is kept
        and completed on a later call. If the file has shrunk (truncated or
        replaced), reading restarts from the beginning. Errors opening or
        reading the file are printed and otherwise ignored, so a missing log
        does not stop the caller.
        """
        try:
            with open(LOG_FILE, 'rb') as f:
                size = f.seek(0, 2)  # whence=2: jump to end; returns file size
                if size < self.last_position:
                    # Log was truncated or replaced: start over from the top and
                    # drop any half-read record from the old contents.
                    self.last_position = 0
                    self._partial = b''
                    self._pending_dims = {}
                f.seek(self.last_position)
                chunk = f.read()
                self.last_position = f.tell()
        except OSError as e:
            print(f"Error: {e}")
            return

        # The last element of split() is whatever follows the final newline;
        # hold it back until the writer finishes that line.
        *complete, self._partial = (self._partial + chunk).split(b'\n')
        for raw in complete:
            self._parse_line(raw.rstrip(b'\r').decode('utf-8', errors='ignore'))

    def _parse_line(self, line):
        """Fold one decoded log line into the pending snapshot state."""
        match = self._EVOLUTION_RE.search(line)
        if match:
            self._pending_evolution = int(match.group(1))
            self._pending_omega = float(match.group(2))

        for dim_match in self._DIM_RE.finditer(line):
            self._pending_dims[int(dim_match.group(1)) - 1] = float(dim_match.group(2))

        # Complete snapshot
        if len(self._pending_dims) == 8:
            self.positions.append(tuple(self._pending_dims[i] for i in range(8)))
            # Publish status only once an Evolution line has been seen, so these
            # public fields always stay numeric rather than becoming None.
            if self._pending_evolution is not None:
                self.current_evolution = self._pending_evolution
                self.current_omega = self._pending_omega
            self._pending_dims = {}

def normalize_dim(data):
    """Map one dimension's values onto [0, 1] on a log10-magnitude scale.

    Each value v becomes ``log10(|v| + 1e-100)``. The tiny offset keeps exact
    zeros finite, at -100. The results are then min-max scaled using only the
    finite ones.

    Args:
        data: 1-D array-like of numbers.

    Returns:
        A float ndarray the same length as ``data``. Non-finite inputs map to
        NaN. When the finite values have no spread (empty, single-element or
        constant input), every finite entry is 0.5, so the points stay inside
        a [0, 1] axis instead of being returned as raw log values.
    """
    logged = np.log10(np.abs(np.asarray(data, dtype=float)) + 1e-100)
    finite = np.isfinite(logged)

    scaled = np.full_like(logged, 0.5)
    if finite.any():
        lo = logged[finite].min()
        hi = logged[finite].max()
        if hi > lo:
            scaled = (logged - lo) / (hi - lo)

    # Exclude non-finite samples (e.g. an overflowed value parsed as inf) from
    # the scale and mark them NaN, so one bad value cannot collapse the rest.
    return np.where(finite, scaled, np.nan)

def _recent_trail_line(ax):
    """Return the single reusable 'recent trail' line on *ax*, creating it on first use."""
    for line in ax.lines:
        if line.get_gid() == 'recent_trail':
            return line
    line, = ax.plot([], [], color='cyan', alpha=0.2, linewidth=0.5, gid='recent_trail')
    return line


def animate(frame, monitor, axes, scatters):
    """Update all phase space projections.

    FuncAnimation callback. It pulls new snapshots from ``monitor`` and
    normalizes every dimension to [0, 1] with normalize_dim(). It then pushes
    the points into the pre-built ``scatters`` (one per subplot, in the order
    of ``projections`` below), draws a faint trail through the most recent
    points, and refreshes the figure title.

    Args:
        frame: Frame index supplied by FuncAnimation (unused).
        monitor: Object exposing update(), positions, current_evolution and
            current_omega.
        axes: 2x3 array of Axes, indexed [row, col].
        scatters: Six scatter collections, one per projection.

    Returns:
        ``scatters`` (the artist sequence FuncAnimation expects).
    """
    monitor.update()

    if len(monitor.positions) < 2:
        return scatters

    # dtype=float so normalized values can never be truncated into an int array.
    arr = np.array(monitor.positions, dtype=float)
    n_points = len(arr)

    # Normalize all dimensions
    for i in range(arr.shape[1]):
        arr[:, i] = normalize_dim(arr[:, i])

    # Color gradient (time evolution): oldest point -> 0, newest -> 1.
    colors = np.linspace(0, 1, n_points)

    # (x-dimension, y-dimension) index pairs; order must match the subplots built in main().
    projections = [(0, 1), (1, 2), (2, 3), (0, 2), (3, 4), (0, 4)]

    recent = min(500, n_points)
    for idx, (i, j) in enumerate(projections):
        ax = axes[idx // 3, idx % 3]
        x, y = arr[:, i], arr[:, j]
        scatters[idx].set_offsets(np.c_[x, y])
        scatters[idx].set_array(colors)

        # Reuse one trail line per axes. Calling ax.plot() every frame would
        # stack a new Line2D on each tick, growing memory and draw time without bound.
        trail = _recent_trail_line(ax)
        if n_points > 50:
            trail.set_data(x[-recent:], y[-recent:])
        else:
            trail.set_data([], [])

    # Update title with status; tolerate a monitor that has no status yet.
    evolution = monitor.current_evolution
    omega = monitor.current_omega
    evolution_text = f'{evolution:,}' if evolution is not None else '?'
    omega_text = f'{omega:.1f}' if omega is not None else '?'
    omega_status = "SATURATED" if omega is not None and omega >= 999 else "GROWING"
    axes[0, 0].figure.suptitle(
        f'Phase Space Geometry | Evolution: {evolution_text} | '
        f'Ω: {omega_text} ({omega_status}) | '
        f'Trail: {n_points} points',
        fontsize=13, fontweight='bold',
        # Default text color is unreadable on the dark figure background.
        color='white',
    )

    return scatters

def main():
    """Build the 2x3 grid of phase-space projections and run the live animation.

    Each subplot shows one pair of dimensions on fixed [0, 1] axes. animate()
    refreshes them every 200 ms from a PhaseSpaceMonitor reading LOG_FILE.
    Blocks in plt.show() until the window is closed.
    """
    print("🌌 Starting Multi-Dimensional Phase Space Visualization...")
    print(f"📊 Monitoring: {LOG_FILE}")
    print("🔄 Showing 6 geometric projections")
    print("🎯 Press Ctrl+C to stop\n")

    monitor = PhaseSpaceMonitor()

    # Create 2x3 grid of phase space plots
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    fig.patch.set_facecolor('#0f0f1e')

    # (x-dim index, y-dim index, subplot title); order must match animate().
    projections = [
        (0, 1, "D1 vs D2 (Core Coupling)"),
        (1, 2, "D2 vs D3 (Inner Ring)"),
        (2, 3, "D3 vs D4 (Mid Layer)"),
        (0, 2, "D1 vs D3 (Cross Section)"),
        (3, 4, "D4 vs D5 (Outer Core)"),
        (0, 4, "D1 vs D5 (Full Span)"),
    ]

    scatters = []

    for idx, (i, j, title) in enumerate(projections):
        ax = axes[idx // 3, idx % 3]
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        # The default text color is unreadable on the dark figure background,
        # so labels are drawn in white like the title and tick labels.
        ax.set_xlabel(f'D{i+1}', fontsize=10, color='white')
        ax.set_ylabel(f'D{j+1}', fontsize=10, color='white')
        ax.set_title(title, fontsize=11, fontweight='bold', color='white')
        ax.set_facecolor('#1a1a2e')
        ax.grid(True, alpha=0.2, color='gray')
        ax.tick_params(colors='white')

        # Empty scatter; animate() fills in offsets and color values each frame.
        scatter = ax.scatter([], [], c=[], cmap='plasma', s=1, alpha=0.6)
        scatters.append(scatter)

    # Reserve the top 5% of the figure for the suptitle that animate() sets;
    # a plain tight_layout() lets that title overlap the top row of subplots.
    fig.tight_layout(rect=(0, 0, 1, 0.95))

    # Keep a reference: an Animation object that is garbage-collected stops updating.
    ani = FuncAnimation(fig, animate,
                        fargs=(monitor, axes, scatters),
                        interval=200,
                        blit=False,
                        cache_frame_data=False)

    try:
        plt.show()
    except KeyboardInterrupt:
        print("\n\n✅ Visualization stopped.")

if __name__ == '__main__':
    # Launch the visualization only when run as a script, never on import.
    main()
