#!/usr/bin/env python3
"""
Batch analysis and visualization for analog codec logs.
Generates:
 - scaling_loglog.png: log(Dn) vs log(t)
 - scaling_vs_sqrtOmega.png: log(Dn) vs sqrt(omega)
 - coupling_heatmap.gif: animated coupling matrix over time
 - lyapunov_spectrum.png: finite-time growth rates per-dimension over time
 - pca_evolution.png: PCA 2D embedding of log-amplitude state over time

Usage:
  python3 scripts/analysis_visuals.py

Saves outputs to `analysis_outputs/`.
"""
import os
import re
import math
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import animation
from sklearn.decomposition import PCA

# Input and output locations are anchored to this script's directory (one
# level up), not to the current working directory.
_BASE_DIR = os.path.join(os.path.dirname(__file__), '..')
LOG_PATH = os.path.join(_BASE_DIR, 'logs', 'peer1.log')
OUT_DIR = os.path.join(_BASE_DIR, 'analysis_outputs')

# Create the output directory at import time so every plotting function can
# write into it without checking first.
if not os.path.exists(OUT_DIR):
    os.makedirs(OUT_DIR, exist_ok=True)

# Snapshot header: "Evolution: <int> │ Phase: ... │ ... Ω: <number>".
# Group 1 is the evolution counter, group 2 the omega value.
EV_RE = re.compile(r"Evolution:\s*(\d+)\s*│\s*Phase:.*│.*Ω:\s*([0-9\.]+)")
# Matches lines like: "  D1: 996701...00 [Dₙ:0.0]"
D_LINE_RE = re.compile(r"D(\d+):\s*([0-9Ee+\-\.]+)\s*\[Dₙ:([0-9Ee+\-\.]+)\]")
# Fallback: accepts any text between "D<k>:" and the "[Dₙ:...]" tag.
GENERIC_D_RE = re.compile(r"D(\d+):\s*([^\[]+)\[Dₙ:([^\]]+)\]")


def parse_log(path):
    """Parse evolution snapshots out of a log file.

    A snapshot begins at every match of ``EV_RE`` and owns the text up to the
    next snapshot header (or the end of the file).  Bounding the block at the
    next header keeps one snapshot's ``D<k>`` lines from being attributed to
    the previous snapshot, and not capping its length keeps very long
    amplitude strings from being cut off.

    Within a block the first occurrence of each ``D1``..``D8`` line wins.
    Lines are matched with ``D_LINE_RE`` first; if any dimension is still
    missing, ``GENERIC_D_RE`` is tried line by line and the first
    whitespace-separated token is taken as the amplitude.

    Args:
        path: Path of the log file.  It is read as UTF-8 and undecodable
            bytes are ignored.

    Returns:
        A list of dicts, one per snapshot, in file order, with keys
        ``'evo'`` (int), ``'omega'`` (float), ``'D_amp'`` (list of 8 amplitude
        strings, ``None`` where missing) and ``'Dn'`` (list of 8 floats,
        ``None`` where missing or unparsable).

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        data = f.read()

    headers = list(EV_RE.finditer(data))
    snapshots = []
    for k, header in enumerate(headers):
        try:
            omega = float(header.group(2))
        except ValueError:
            # The character class admits text such as "1.2.3" that float()
            # rejects; skip the malformed header instead of aborting the run.
            continue
        evo = int(header.group(1))

        start = header.end()
        end = headers[k + 1].start() if k + 1 < len(headers) else len(data)
        block = data[start:end]

        D_amp = [None] * 8
        Dn = [None] * 8

        # Strict pass: purely numeric amplitude and Dn.
        for m2 in D_LINE_RE.finditer(block):
            i = int(m2.group(1))
            if not 1 <= i <= 8 or D_amp[i - 1] is not None:
                continue
            try:
                dn_value = float(m2.group(3))
            except ValueError:
                # e.g. "-" or "1e": leave the slot for the fallback pass.
                continue
            D_amp[i - 1] = m2.group(2).strip()
            Dn[i - 1] = dn_value

        # Fallback pass: only fills the dimensions the strict pass missed.
        if any(x is None for x in D_amp):
            for line in block.splitlines():
                m3 = GENERIC_D_RE.match(line.strip())
                if not m3:
                    continue
                i = int(m3.group(1))
                if not 1 <= i <= 8 or D_amp[i - 1] is not None:
                    continue
                tokens = m3.group(2).split()
                if not tokens:
                    continue  # nothing but whitespace before the tag
                D_amp[i - 1] = tokens[0]
                try:
                    Dn[i - 1] = float(m3.group(3))
                except ValueError:
                    # Keep the amplitude; Dn stays None for this dimension.
                    pass

        snapshots.append({'evo': evo, 'omega': omega, 'D_amp': D_amp, 'Dn': Dn})
    return snapshots


# Plain decimal literal: optional sign, integer digits, optional fraction,
# optional base-10 exponent.
_DECIMAL_RE = re.compile(r"([+-]?)([0-9]*)(?:\.([0-9]*))?(?:[eE]([+-]?[0-9]+))?")
# Legacy best-effort pattern: a run of digits followed by "." or end of text.
_INT_PREFIX_RE = re.compile(r"([0-9]+)(?:\.|$)")


def approx_log10_from_string(s):
    """Return log10 of the number written in ``s``, tolerating huge values.

    Values that fit in a float are handled with ``math.log10``.  Values that
    overflow to infinity or underflow to zero when converted (for example a
    400-digit integer, ``"1.5e400"`` or ``"1e-400"``) are computed from the
    text itself: log10 of the first 15 significant digits plus the decimal
    exponent implied by the digit positions and any ``e``-suffix.

    Args:
        s: Text of the number, or ``None``.  Surrounding whitespace and
            thousands-separator commas are ignored.

    Returns:
        float: log10 of the value; ``-inf`` for zero; NaN for ``None``,
        negative numbers and unparsable text.  Text that is not a valid
        number but starts with digits followed by ``.`` (e.g. an elided
        ``"9967...00"``) yields log10 of that digit prefix, which is only a
        rough lower bound.
    """
    if s is None:
        return np.nan
    s = s.strip().replace(',', '')
    try:
        v = float(s)
    except ValueError:
        v = None
    if v is not None and v > 0 and math.isfinite(v):
        return math.log10(v)

    # Reached for zero/underflow, overflow, negatives, NaN and non-floats.
    m = _DECIMAL_RE.fullmatch(s)
    if m:
        sign, int_part, frac_part, exp_part = m.groups()
        digits = int_part + (frac_part or '')
        if digits:
            significant = digits.lstrip('0')
            if not significant:
                return -np.inf  # every digit is zero
            if sign == '-':
                return np.nan
            leading = significant[:15]
            skipped_zeros = len(digits) - len(significant)
            exponent = float(exp_part) if exp_part else 0.0
            # value ~= 0.<significant> * 10**(len(int_part) - skipped_zeros + exponent)
            return (math.log10(float(leading)) - len(leading)
                    + len(int_part) - skipped_zeros + exponent)

    if v == 0:
        # float() accepted a zero spelling the decimal pattern does not cover.
        return -np.inf

    m = _INT_PREFIX_RE.match(s)
    if m:
        significant = m.group(1).lstrip('0')
        if significant:
            leading = significant[:15]
            return (len(significant) - len(leading)) + math.log10(float(leading))
    return np.nan


def build_matrices(snapshots):
    """Convert parsed snapshots into dense numpy arrays.

    Args:
        snapshots: Sequence of dicts with keys ``'evo'``, ``'omega'``,
            ``'D_amp'`` and ``'Dn'``; the last two are indexed 0..7.

    Returns:
        Tuple ``(times, omegas, A, Dn)``:
        ``times`` and ``omegas`` are float64 vectors with one entry per
        snapshot; ``A`` is a (len(snapshots), 8) array holding
        ``approx_log10_from_string`` of each amplitude string; ``Dn`` is a
        (len(snapshots), 8) array of the Dn values with NaN where the
        snapshot stored ``None``.
    """
    n_rows = len(snapshots)
    times = np.array([snap['evo'] for snap in snapshots], dtype=np.float64)
    omegas = np.array([snap['omega'] for snap in snapshots], dtype=np.float64)

    A = np.zeros((n_rows, 8))
    Dn = np.zeros((n_rows, 8))
    for row, snap in enumerate(snapshots):
        for col in range(8):
            A[row, col] = approx_log10_from_string(snap['D_amp'][col])
            raw = snap['Dn'][col]
            Dn[row, col] = np.nan if raw is None else raw
    return times, omegas, A, Dn


def _plot_scaling_panel(xs_all, Dn, xlabel, title, filename):
    """Plot log10(Dn[:, j]) against ``xs_all`` per dimension with a line fit and save it."""
    fig, ax = plt.subplots(figsize=(10, 6))
    x_ok = np.isfinite(xs_all)
    fitted_any = False
    for j in range(Dn.shape[1]):
        column = Dn[:, j]
        with np.errstate(invalid='ignore'):
            # log10 needs strictly positive, finite Dn; the fit needs finite x.
            valid = x_ok & np.isfinite(column) & (column > 0)
        if np.count_nonzero(valid) < 5:
            continue  # too few points for a meaningful slope
        xs = xs_all[valid]
        ys = np.log10(column[valid])
        coef = np.polyfit(xs, ys, 1)
        ax.plot(xs, ys, '.', alpha=0.6)
        ax.plot(xs, np.polyval(coef, xs), '-', label=f'D{j+1} slope={coef[0]:.3f}')
        fitted_any = True
    ax.set_xlabel(xlabel)
    ax.set_ylabel('log10(Dn)')
    ax.set_title(title)
    if fitted_any:
        # legend() warns when no labelled artist exists.
        ax.legend()
    ax.grid(True)
    fig.savefig(os.path.join(OUT_DIR, filename), dpi=150)
    plt.close(fig)


def plot_scaling(times, omegas, A, Dn):
    """Write the two scaling plots for Dn.

    For every dimension (column of ``Dn``) with at least five usable points,
    the points and a first-degree least-squares fit are drawn; the fitted
    slope is shown in the legend.

    Outputs, both written into ``OUT_DIR``:
      - ``scaling_loglog.png``: log10(Dn) vs log10(evolution).  Snapshots
        with evolution <= 0 are left out because their log10 is undefined.
      - ``scaling_vs_sqrtOmega.png``: log10(Dn) vs sqrt(omega).

    Args:
        times: 1-D array of evolution counters, one per snapshot.
        omegas: 1-D array of omega values, one per snapshot.
        A: Unused; kept so the call signature stays unchanged.
        Dn: 2-D array, rows are snapshots and columns are dimensions; NaN
            marks a missing value.
    """
    times = np.asarray(times, dtype=np.float64)
    omegas = np.asarray(omegas, dtype=np.float64)
    Dn = np.asarray(Dn, dtype=np.float64)

    # NaN marks x-values that cannot be plotted; the panel helper drops them.
    log_t = np.full(times.shape, np.nan)
    positive = times > 0
    log_t[positive] = np.log10(times[positive])

    sqrt_omega = np.full(omegas.shape, np.nan)
    non_negative = omegas >= 0
    sqrt_omega[non_negative] = np.sqrt(omegas[non_negative])

    _plot_scaling_panel(
        log_t, Dn,
        xlabel='log10(Evolution)',
        title='Scaling: log10(Dn) vs log10(Evolution)',
        filename='scaling_loglog.png',
    )
    _plot_scaling_panel(
        sqrt_omega, Dn,
        xlabel='sqrt(omega)',
        title='Scaling: log10(Dn) vs sqrt(omega)',
        filename='scaling_vs_sqrtOmega.png',
    )


def animate_coupling(Dn, times):
    """Save an animated heatmap of the pairwise coupling matrix.

    For snapshot ``t`` the coupling between dimensions ``i`` and ``j`` is
    ``exp(-|Dn[t, i] - Dn[t, j]| / 50)``; a pair involving a missing (NaN)
    value is shown as 0.  Frames are taken with a stride of
    ``max(1, len(times) // 100)`` so long logs produce fewer than 200 frames.

    The GIF is written to ``OUT_DIR/coupling_heatmap.gif`` with matplotlib's
    ``'imagemagick'`` writer.  Whatever ``anim.save`` raises propagates to the
    caller; the figure is closed either way.

    Args:
        Dn: 2-D array, rows are snapshots and columns are dimensions.  Must
            contain at least one row.
        times: 1-D array of evolution counters, one per row of ``Dn``.
    """
    Dn = np.asarray(Dn, dtype=np.float64)
    n_dims = Dn.shape[1]

    # All pairwise differences at once: (T, n, 1) - (T, 1, n) -> (T, n, n).
    with np.errstate(invalid='ignore'):
        diff = np.abs(Dn[:, :, np.newaxis] - Dn[:, np.newaxis, :])
        mats = np.exp(-diff / 50.0)
    missing = np.isnan(Dn)
    mats[missing[:, :, np.newaxis] | missing[:, np.newaxis, :]] = 0.0

    labels = [f'D{i+1}' for i in range(n_dims)]
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        im = ax.imshow(mats[0], vmin=0, vmax=1, cmap='viridis')
        ax.set_xticks(range(n_dims))
        ax.set_yticks(range(n_dims))
        ax.set_xticklabels(labels)
        ax.set_yticklabels(labels)
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        title = ax.text(0.5, 1.05, '', size=12, ha='center', transform=ax.transAxes)

        def update(frame):
            im.set_data(mats[frame])
            title.set_text(f'Evolution: {int(times[frame])}')
            return (im, title)

        stride = max(1, len(times) // 100)
        anim = animation.FuncAnimation(
            fig, update, frames=range(0, len(times), stride), blit=False)
        outpath = os.path.join(OUT_DIR, 'coupling_heatmap.gif')
        anim.save(outpath, writer='imagemagick', fps=6)
    finally:
        # Release the figure even when the writer is unavailable.
        plt.close(fig)


def finite_time_lyapunov(A, times):
    """Plot per-dimension growth rates between consecutive snapshots.

    ``A`` holds log10 amplitudes (rows are snapshots, the first 8 columns are
    used).  For each pair of consecutive snapshots the rate is
    ``(ln A[t+1] - ln A[t]) / (times[t+1] - times[t])``, with a zero time
    step replaced by 1.  A rate is NaN when either amplitude is not finite.
    The result is drawn as a heatmap (colour range clipped to +/-0.01) and
    saved to ``OUT_DIR/lyapunov_spectrum.png``.

    Args:
        A: 2-D array of log10 amplitudes.
        times: 1-D array of evolution counters, one per row of ``A``.
    """
    n_snapshots = A.shape[0]
    step = np.diff(times)
    step[step == 0] = 1  # repeated evolution counter: avoid dividing by zero

    ln10 = math.log(10)  # log10 -> natural log conversion factor
    lambdas = np.full((n_snapshots - 1, 8), np.nan)
    for dim in range(8):
        before = A[:-1, dim]
        after = A[1:, dim]
        usable = np.isfinite(before) & np.isfinite(after)
        lambdas[usable, dim] = (after[usable] * ln10 - before[usable] * ln10) / step[usable]

    # Heatmap: dimensions on the y-axis, snapshot index on the x-axis.
    fig, ax = plt.subplots(figsize=(10, 6))
    heat = ax.imshow(lambdas.T, aspect='auto', cmap='seismic', vmin=-0.01, vmax=0.01)
    fig.colorbar(heat, ax=ax, label='lambda (1/evolution)')
    ax.set_yticks(range(8))
    ax.set_yticklabels([f'D{i+1}' for i in range(8)])
    ax.set_xlabel('snapshot index')
    ax.set_title('Finite-time Lyapunov approx (d ln A / dt)')
    fig.savefig(os.path.join(OUT_DIR, 'lyapunov_spectrum.png'), dpi=150)
    plt.close(fig)


def pca_embedding(A, times):
    """Project the log-amplitude state onto two principal components and plot it.

    A single PCA is fitted over all snapshots at once.  Every non-finite
    entry of ``A`` (NaN as well as +/-inf) is first replaced by the smallest
    finite entry minus 1.  Points are coloured by snapshot index and the
    figure is saved to ``OUT_DIR/pca_evolution.png``.

    Args:
        A: 2-D array of log10 amplitudes, rows are snapshots.
        times: Unused; the colour scale is the row index, not the evolution
            counter.
    """
    state = np.copy(A)
    finite = np.isfinite(state)
    floor_value = np.nanmin(state[finite]) - 1
    state[~finite] = floor_value

    embedding = PCA(n_components=2).fit_transform(state)
    order = np.arange(embedding.shape[0])

    fig, ax = plt.subplots(figsize=(8, 6))
    points = ax.scatter(embedding[:, 0], embedding[:, 1], c=order, cmap='plasma', s=6)
    fig.colorbar(points, ax=ax, label='snapshot index')
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_title('PCA embedding of log-amplitude state over time')
    fig.savefig(os.path.join(OUT_DIR, 'pca_evolution.png'), dpi=150)
    plt.close(fig)


def main():
    """Run the whole pipeline: parse the log, then write every plot.

    Steps: parse ``LOG_PATH``, build the matrices, write the two scaling
    plots, try to write the coupling GIF (falling back to a single PNG of
    the last snapshot when that fails), then write the Lyapunov heatmap and
    the PCA embedding.  Returns early with a message when the log cannot be
    read or contains no snapshots.
    """
    print('Parsing log...')
    try:
        snapshots = parse_log(LOG_PATH)
    except OSError as e:
        # A missing or unreadable log is reported, not raised as a traceback.
        print('Could not read log file:', e)
        snapshots = []
    if not snapshots:
        print('No snapshots parsed. Make sure logs/peer1.log exists and contains Evolution snapshots.')
        return
    print(f'Parsed {len(snapshots)} snapshots.')
    times, omegas, A, Dn = build_matrices(snapshots)

    print('Building scaling plots...')
    plot_scaling(times, omegas, A, Dn)

    print('Animating coupling heatmap... (this may take a moment)')
    try:
        animate_coupling(Dn, times)
    except Exception as e:
        # Deliberately broad: the GIF is optional and the writer can fail in
        # several ways; the remaining plots must still be produced.
        print('Failed to create GIF (imagemagick required). Error:', e)
        # Fallback: save the coupling matrix of the last snapshot only.
        last = Dn[-1]
        coupling = np.exp(-np.abs(last[None, :] - last[:, None]) / 50.0)
        fig = plt.figure(figsize=(5, 5))
        try:
            plt.imshow(coupling, vmin=0, vmax=1, cmap='viridis')
            plt.colorbar()
            plt.title('Coupling snapshot')
            plt.savefig(os.path.join(OUT_DIR, 'coupling_snapshot.png'))
        finally:
            plt.close(fig)

    print('Computing finite-time Lyapunov approx...')
    finite_time_lyapunov(A, times)
    print('Computing PCA embedding...')
    pca_embedding(A, times)
    print('All outputs saved to', OUT_DIR)


if __name__ == '__main__':
    main()
