#!/usr/bin/env python3
"""
GPU-Accelerated Trajectory Inference
Uses your RX 480 to extrapolate attractor trajectory to 1 trillion evolutions
Based on learned dynamics from CPU log data
"""

import numpy as np
import time
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
import re

# Log file the dynamics are learned from. This is a relative path, so it is
# resolved against the current working directory when the file is opened.
LOG_FILE = "../logs/peer1.log"

class TrajectoryInferencer:
    """Fit simple trend models to logged attractor state and extrapolate them.

    On construction the most recent records are parsed from ``LOG_FILE`` into
    ``self.history`` and summarised by :meth:`learn_dynamics` into:

    * ``omega_rate`` - median change of Ω per evolution (0.0 when saturated),
    * ``amp_params`` - per-dimension log10-amplitude trend plus oscillation,
    * ``dn_params``  - per-dimension linear fit of Dₙ against sqrt(Ω).

    :meth:`infer_state` and :meth:`generate_trajectory` evaluate those models
    at arbitrary evolution counts.
    """

    NUM_DIMS = 8             # dimensions D1..D8 tracked in the log
    MAX_LOG_LINES = 50000    # only the tail of the log is parsed, for speed
    OMEGA_CAP = 1000.0       # extrapolated Ω never exceeds this value
    OMEGA_SATURATED = 999.0  # Ω at or above this is treated as saturated

    # Compiled once; the parser applies them to every line of the log tail.
    _EVOLUTION_RE = re.compile(r'Evolution: (\d+).*Ω: ([\d.]+)')
    # The amplitude group also accepts a sign and a negative exponent
    # (e.g. "2.5e-03"), which the previous character class rejected.
    _DIM_RES = tuple(
        re.compile(rf'D{i}: ([-+\d.eE]+) \[Dₙ:([\d.]+)\]')
        for i in range(1, NUM_DIMS + 1)
    )

    def __init__(self):
        """Load the recent history from the log and learn the dynamics."""
        print("🚀 GPU-Accelerated Trajectory Inference")
        print("=" * 70)
        print("\n📖 Learning dynamics from CPU log data...")

        self.history = self.load_recent_history(samples=5000)
        print(f"✓ Loaded {len(self.history['evolution'])} data points")

        # Learn dynamics
        self.learn_dynamics()

    def load_recent_history(self, samples=5000):
        """Load the most recent evolution records from ``LOG_FILE``.

        Only the last ``MAX_LOG_LINES`` lines of the file are parsed, and of
        the complete records found there the newest ``samples`` are kept.

        Args:
            samples: Maximum number of records to return.

        Returns:
            A dict with keys ``'evolution'`` (list of int), ``'omega'`` (list
            of float), and ``'amplitudes'`` / ``'dn'`` (one list of float per
            dimension). All lists are empty when the log cannot be read.
        """
        from collections import deque  # local: keeps the file's import block untouched

        try:
            with open(LOG_FILE, 'r', encoding='utf-8', errors='ignore') as f:
                # A bounded deque keeps only the tail without holding the
                # whole (possibly very large) log in memory.
                lines = deque(f, maxlen=self.MAX_LOG_LINES)
        except OSError as e:
            # Best effort: report and continue with an empty history.
            print(f"Error loading history: {e}")
            lines = []

        return self._parse_history(lines, samples)

    @classmethod
    def _parse_history(cls, lines, samples):
        """Parse log lines into a history dict, keeping the newest ``samples`` records."""
        records = []
        current_evolution = None
        current_omega = None
        dims = {}

        for line in lines:
            match = cls._EVOLUTION_RE.search(line)
            if match:
                try:
                    evolution = int(match.group(1))
                    omega = float(match.group(2))
                except ValueError:
                    pass  # malformed number (e.g. "1.2.3"): ignore this header
                else:
                    current_evolution, current_omega = evolution, omega

            for idx, pattern in enumerate(cls._DIM_RES):
                match = pattern.search(line)
                if not match:
                    continue
                try:
                    dims[idx] = (float(match.group(1)), float(match.group(2)))
                except ValueError:
                    continue  # skip a malformed value instead of aborting the load

            # Compare against None so that evolution 0 / Ω 0.0 are not dropped.
            if (current_evolution is not None and current_omega is not None
                    and len(dims) == cls.NUM_DIMS):
                records.append((current_evolution, current_omega, dims))
                dims = {}

        # Keep the tail: later code treats the last record as the current state.
        kept = records[max(0, len(records) - samples):]
        return {
            'evolution': [rec[0] for rec in kept],
            'omega': [rec[1] for rec in kept],
            'amplitudes': [[rec[2][d][0] for rec in kept]
                           for d in range(cls.NUM_DIMS)],
            'dn': [[rec[2][d][1] for rec in kept]
                   for d in range(cls.NUM_DIMS)],
        }

    @staticmethod
    def _linear_fit(x, y):
        """Return ``(slope, intercept)`` of a least-squares line through finite points.

        Falls back to a flat line through the mean of ``y`` when fewer than
        two usable points exist or ``x`` has no spread (a rank-deficient fit).
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        m = min(x.size, y.size)
        x, y = x[:m], y[:m]
        usable = np.isfinite(x) & np.isfinite(y)
        x, y = x[usable], y[usable]
        if y.size == 0:
            return 0.0, 0.0
        if y.size < 2 or x.max() == x.min():
            return 0.0, float(np.mean(y))
        slope, intercept = np.polyfit(x, y, 1)
        return float(slope), float(intercept)

    def learn_dynamics(self):
        """Learn the evolution dynamics from ``self.history``.

        Sets ``omega_rate``, ``amp_params`` and ``dn_params``. With fewer than
        two records the attributes are still defined (rate 0.0, empty lists)
        and a warning is printed.
        """
        print("\n🧠 Learning attractor dynamics...")

        # Define the model attributes up front so an early return never
        # leaves the instance half-initialised.
        self.omega_rate = 0.0
        self.amp_params = []
        self.dn_params = []

        n = len(self.history['evolution'])
        if n < 2:
            print("⚠ Not enough data!")
            return

        evols = np.asarray(self.history['evolution'], dtype=float)
        omegas = np.asarray(self.history['omega'], dtype=float)

        # Omega growth rate (with saturation)
        if omegas[-1] >= self.OMEGA_SATURATED:
            print("  Ω: SATURATED at 1000")
        else:
            delta_omega = np.diff(omegas)
            delta_evol = np.diff(evols)
            # Repeated evolution numbers would divide by zero; skip them.
            moving = delta_evol != 0
            if np.any(moving):
                self.omega_rate = float(
                    np.median(delta_omega[moving] / delta_evol[moving]))
            print(f"  Ω growth rate: {self.omega_rate:.6f} per evolution")

        # infer_state() evaluates offset + growth * (target - last evolution),
        # so fit against evolutions measured from the newest record: the
        # slope is then per evolution and the intercept is the current value.
        rel_evols = evols - evols[-1]
        span = evols[-1] - evols[0]
        mean_step = span / (n - 1)  # evolutions per logged sample

        for i in range(self.NUM_DIMS):
            amps = np.asarray(self.history['amplitudes'][i], dtype=float)
            # Take log for huge values
            log_amps = np.log10(np.abs(amps) + 1e-100)

            # Linear in log space = exponential growth/decay
            growth, offset = self._linear_fit(rel_evols, log_amps)

            # Heuristic: treat each upward step as one cycle, giving radians
            # per sample, then convert to radians per evolution.
            rises = max(1, int(np.sum(np.diff(log_amps) > 0)))
            freq_estimate = 2 * np.pi * rises / len(log_amps)
            if mean_step > 0:
                freq_estimate /= mean_step

            self.amp_params.append({
                'growth': growth,
                'offset': offset,
                'frequency': freq_estimate,
                'phase': 0.0,
            })

            print(f"  D{i+1}: growth={growth:.6f}, freq={freq_estimate:.4f}")

        # Dₙ evolution (function of Ω); fit: Dₙ = a * sqrt(Ω) + b
        sqrt_omega = np.sqrt(omegas)
        for i in range(self.NUM_DIMS):
            dns = np.asarray(self.history['dn'][i], dtype=float)
            # When Ω is constant (e.g. saturated) the fit degenerates to
            # a = 0, b = mean(Dₙ) instead of a rank-deficient polyfit.
            a, b = self._linear_fit(sqrt_omega, dns)
            self.dn_params.append({'a': a, 'b': b})
            print(f"  D{i+1} Dₙ: {a:.2f}*sqrt(Ω) + {b:.2f}")

        print("\n✓ Dynamics learned!\n")

    def _require_dynamics(self):
        """Raise ``RuntimeError`` when there is no history or no fitted model."""
        if (not self.history['evolution']
                or not getattr(self, 'amp_params', None)
                or not getattr(self, 'dn_params', None)):
            raise RuntimeError(
                "No learned dynamics available: the log yielded fewer than "
                "2 complete data points")

    def infer_state(self, target_evolution):
        """Infer the state at ``target_evolution`` using the learned dynamics.

        Args:
            target_evolution: Evolution count to extrapolate to.

        Returns:
            A dict with keys ``'evolution'``, ``'omega'``, ``'amplitudes'``
            (one value per dimension) and ``'dn_values'`` (one per dimension).

        Raises:
            RuntimeError: If no dynamics have been learned.
        """
        self._require_dynamics()

        start_evol = self.history['evolution'][-1]
        start_omega = self.history['omega'][-1]
        delta_evol = target_evolution - start_evol

        # Compute omega
        if start_omega >= self.OMEGA_SATURATED:
            omega = self.OMEGA_CAP
        else:
            omega = start_omega + self.omega_rate * delta_evol
            # Clamp to [0, cap]: sqrt(Ω) below is undefined for negative Ω.
            omega = min(self.OMEGA_CAP, max(0.0, omega))

        # Compute amplitudes (oscillating with growth/decay)
        amplitudes = []
        # Far extrapolation can exceed float range; np.power then yields inf
        # (a plain Python float power would raise OverflowError).
        with np.errstate(over='ignore'):
            for params in self.amp_params:
                # Log-space linear evolution
                log_amp = params['offset'] + params['growth'] * delta_evol
                phase = params['phase'] + params['frequency'] * delta_evol
                # ±0.2 decades in log10 space, i.e. a factor of about 1.58
                oscillation = 0.2 * np.sin(phase)
                amplitudes.append(np.power(10.0, log_amp + oscillation))

        # Compute Dₙ values
        sqrt_omega = np.sqrt(omega)
        dn_values = [params['a'] * sqrt_omega + params['b']
                     for params in self.dn_params]

        return {
            'evolution': target_evolution,
            'omega': omega,
            'amplitudes': amplitudes,
            'dn_values': dn_values,
        }

    def generate_trajectory(self, num_points=1000, target_evolution=1e12):
        """Generate the inferred trajectory from the current to the target evolution.

        Args:
            num_points: Number of linearly spaced evolution counts to evaluate.
            target_evolution: Final evolution count of the trajectory.

        Returns:
            A list of ``num_points`` state dicts as returned by
            :meth:`infer_state`.

        Raises:
            RuntimeError: If no dynamics have been learned.
        """
        self._require_dynamics()
        print(f"🎯 Generating trajectory to {target_evolution:.2e} evolutions...")

        start_evol = self.history['evolution'][-1]
        evolutions = np.linspace(start_evol, target_evolution, num_points)
        trajectory = [self.infer_state(evol) for evol in evolutions]

        print(f"✓ Generated {num_points} trajectory points\n")
        return trajectory

def visualize_inference(inferencer):
    """Plot the extrapolated trajectory in six panels and print its final state.

    Requests a 2000-point trajectory up to 1e12 evolutions from
    ``inferencer.generate_trajectory`` and draws: per-dimension log10
    amplitudes, per-dimension Dₙ, Ω, a normalised 3D scatter of D1-D3, the
    pairwise coupling matrix of the final state, and a D1-vs-D5 scatter.
    A textual summary of the final state is printed before the figure is
    shown with ``plt.show()``.

    Args:
        inferencer: Object providing ``generate_trajectory(num_points=...,
            target_evolution=...)`` that returns a list of state dicts with
            keys ``'evolution'``, ``'omega'``, ``'amplitudes'`` and
            ``'dn_values'``.
    """
    print("🌈 Generating visualization to 1 TRILLION evolutions...")

    trajectory = inferencer.generate_trajectory(num_points=2000, target_evolution=1e12)

    dim_labels = [f'D{k+1}' for k in range(8)]

    def log_amplitude(state, k):
        # log10 of one amplitude, floored at 1e-100 so log10 never sees 0
        return np.log10(max(state['amplitudes'][k], 1e-100))

    def coupling_between(state, a, b):
        # Dₙ gap of two dimensions and the coupling strength derived from it
        gap = abs(state['dn_values'][a] - state['dn_values'][b])
        return gap, np.exp(-gap / 50.0)

    def frame(ax, title, xlabel, ylabel):
        # Shared title / axis labels / light grid for the line panels
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)

    fig = plt.figure(figsize=(16, 10))
    fig.suptitle('GPU-Inferred Trajectory to 1 Trillion Evolutions',
                 fontsize=14, fontweight='bold')

    # 2x3 grid of panels; slot 4 is the only 3D axes
    ax_amp, ax_dn, ax_omega = (plt.subplot(2, 3, slot) for slot in (1, 2, 3))
    ax_3d = plt.subplot(2, 3, 4, projection='3d')
    ax_coupling, ax_phase = (plt.subplot(2, 3, slot) for slot in (5, 6))

    # Shared x data and Ω series
    evolutions, omegas = [], []
    for state in trajectory:
        evolutions.append(state['evolution'])
        omegas.append(state['omega'])

    colors = plt.cm.rainbow(np.linspace(0, 1, 8))

    # Panel 1: amplitudes
    frame(ax_amp, 'Dimensional Amplitudes (log scale)', 'Evolution', 'log₁₀(Amplitude)')
    for k, label in enumerate(dim_labels):
        series = [log_amplitude(state, k) for state in trajectory]
        ax_amp.plot(evolutions, series, label=label, color=colors[k])
    ax_amp.legend(ncol=4, fontsize=7)
    ax_amp.set_xscale('log')

    # Panel 2: Dₙ
    frame(ax_dn, 'Dimensional Separation (Dₙ)', 'Evolution', 'Dₙ')
    ax_dn.axhline(y=500, color='red', linestyle='--', alpha=0.5, label='Coupling threshold')
    for k, label in enumerate(dim_labels):
        series = [state['dn_values'][k] for state in trajectory]
        ax_dn.plot(evolutions, series, label=label, color=colors[k])
    ax_dn.legend(ncol=4, fontsize=7)
    ax_dn.set_xscale('log')

    # Panel 3: Ω
    frame(ax_omega, 'Ω Evolution (V4.3 Saturation)', 'Evolution', 'Ω')
    ax_omega.axhline(y=1000, color='red', linestyle='--', linewidth=2, label='Cap')
    ax_omega.plot(evolutions, omegas, 'b-', linewidth=2)
    ax_omega.legend()
    ax_omega.set_xscale('log')

    # Panel 4: 3D scatter of the first three dimensions in log space
    ax_3d.set_title('3D Attractor Geometry (D1-D2-D3)')
    positions = np.array([
        [np.log10(max(abs(state['amplitudes'][axis]), 1e-100)) for axis in range(3)]
        for state in trajectory
    ])

    # Rescale each coordinate to [0, 1]; a constant column is left untouched
    for axis in range(3):
        column = positions[:, axis]
        low, high = column.min(), column.max()
        if high > low:
            positions[:, axis] = (column - low) / (high - low)

    # Colour encodes position along the trajectory
    time_colors = np.linspace(0, 1, len(positions))
    ax_3d.scatter(positions[:, 0], positions[:, 1], positions[:, 2],
                  c=time_colors, cmap='plasma', s=2, alpha=0.6)
    ax_3d.set_xlabel('D1')
    ax_3d.set_ylabel('D2')
    ax_3d.set_zlabel('D3')

    # Panel 5: pairwise coupling of the last state (diagonal stays 0)
    ax_coupling.set_title('Final Coupling Matrix (@ 1T evolutions)')
    final = trajectory[-1]
    coupling_matrix = np.zeros((8, 8))
    for row in range(8):
        for col in range(8):
            if row == col:
                continue
            coupling_matrix[row, col] = coupling_between(final, row, col)[1]

    im = ax_coupling.imshow(coupling_matrix, cmap='hot', vmin=0, vmax=1)
    ax_coupling.set_xticks(range(8))
    ax_coupling.set_yticks(range(8))
    ax_coupling.set_xticklabels(list(dim_labels))
    ax_coupling.set_yticklabels(list(dim_labels))
    plt.colorbar(im, ax=ax_coupling)

    # Panel 6: D1 against D5 in log space
    ax_phase.set_title('Phase Space: D1 vs D5')
    d1_vals = [log_amplitude(state, 0) for state in trajectory]
    d5_vals = [log_amplitude(state, 4) for state in trajectory]
    ax_phase.scatter(d1_vals, d5_vals, c=time_colors, cmap='viridis', s=1, alpha=0.6)
    ax_phase.set_xlabel('log₁₀(D1)')
    ax_phase.set_ylabel('log₁₀(D5)')
    ax_phase.grid(True, alpha=0.3)

    plt.tight_layout()

    # Textual summary of the last trajectory point
    rule = "="*70
    print("\n" + rule)
    print("  INFERRED STATE AT 1 TRILLION EVOLUTIONS")
    print(rule)
    print(f"\nEvolution: {final['evolution']:.2e}")
    print(f"Ω: {final['omega']:.4f}")
    print("\nDimensional Amplitudes:")
    for k, label in enumerate(dim_labels):
        print(f"  {label}: {final['amplitudes'][k]:.2e}")
    print("\nDimensional Separation (Dₙ):")
    for k, label in enumerate(dim_labels):
        print(f"  {label}: {final['dn_values'][k]:.2f}")

    # Classify every unordered pair of dimensions by coupling strength
    print("\nCoupling Status:")
    for first in range(8):
        for second in range(first + 1, 8):
            delta_dn, coupling = coupling_between(final, first, second)
            if coupling > 0.01:
                status = "STRONG"
            elif coupling > 1e-5:
                status = "WEAK"
            else:
                status = "DECOUPLED"
            print(f"  {dim_labels[first]} ↔ {dim_labels[second]}: ΔDₙ={delta_dn:.1f}, coupling={coupling:.2e} [{status}]")

    plt.show()

def main():
    """Print the banner, build the inferencer and show the visualization."""
    rule = "=" * 70
    banner = (
        "",
        rule,
        "  🚀 GPU-ACCELERATED TRAJECTORY INFERENCE",
        "  Extrapolate to 1 Trillion Evolutions",
        rule,
        "",
    )
    # The empty first and last entries reproduce the blank lines around the banner.
    print("\n".join(banner))

    visualize_inference(TrajectoryInferencer())

# Run the inference and visualization only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
