import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from tqdm import tqdm
from joblib import Parallel, delayed

# The 168 primes below 1000, in ascending order.
# D() uses this as a lookup table, indexing it with floor(n + beta)
# reduced modulo the table length.
PRIMES = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
    73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151,
    157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233,
    239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317,
    331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419,
    421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503,
    509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607,
    613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701,
    709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811,
    821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911,
    919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997
]

# Golden ratio, shared by fib_real() and D().
phi = (1 + np.sqrt(5)) / 2
# Memo for fib_real(), keyed by the (possibly fractional) index.
fib_cache = {}

def fib_real(n):
    """Return the Fibonacci number extended to a real index ``n``.

    Uses the real-valued Binet form

        F(n) = (phi**n - cos(pi * n) * phi**(-n)) / sqrt(5)

    which reproduces the integer sequence (F(0)=0, F(1)=1, F(10)=55) and
    interpolates between integers.  Results are memoised in ``fib_cache``.

    Indices above 100 return 0.0 without being cached; this cutoff is
    inherited behaviour (presumably an overflow guard for D() -- confirm
    before changing it, since D() clamps a zero product to its floor).
    """
    if n in fib_cache:
        return fib_cache[n]
    # Imported lazily so cache hits stay as cheap as a dict lookup.
    from math import cos, pi, sqrt
    if n > 100:
        return 0.0
    # Both Binet terms share the 1/sqrt(5) factor; dividing only the first
    # one yields F(1) ~= 1.34 instead of 1.
    result = (phi ** n - cos(pi * n) * phi ** (-n)) / sqrt(5)
    fib_cache[n] = result
    return result

def D(n, beta, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0):
    """Evaluate the symbolic-dimension model at index ``n + beta``.

    With ``x = n + beta`` the result is

        sqrt(scale * phi * F * base**x * P * Omega) * r**k

    where ``F = fib_real(x)`` and ``P`` is the ``PRIMES`` entry selected by
    ``floor(x)`` wrapped around the table length.  The product under the
    square root is clamped to at least ``1e-30``, so zero or negative
    products all map to the same tiny positive value.
    """
    x = n + beta
    fib_term = fib_real(x)
    table_size = len(PRIMES)
    prime_term = PRIMES[int(np.floor(x) + table_size) % table_size]
    # Factors are multiplied in a fixed left-to-right order on purpose:
    # floating-point products are not associative.
    radicand = scale * phi * fib_term * base ** x * prime_term * Omega
    return np.sqrt(np.maximum(radicand, 1e-30)) * (r ** k)

def invert_D(value, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0, max_n=500, steps=500):
    """Grid-search the (n, beta, scale multiplier) whose D() is closest to ``value``.

    The search visits ``steps`` evenly spaced ``n`` in ``[0, max_n]``, ten
    ``beta`` values in ``[0, 1]`` and ten log-spaced scale multipliers
    covering two decades either side of ``|value|`` (floored at 1e-30).
    Each multiplier is applied on top of ``scale``.

    Returns:
        ``(n, beta, dynamic_scale)`` for the grid point with the smallest
        ``|D - value|``; on ties the first point visited wins.
    """
    magnitude = np.log10(max(abs(value), 1e-30))
    scale_grid = np.logspace(magnitude - 2, magnitude + 2, num=10)
    beta_grid = np.linspace(0, 1, 10)

    def trial(n, beta, factor):
        # One grid point: (absolute miss, n, beta, scale multiplier).
        model = D(n, beta, r, k, Omega, base, scale * factor)
        return abs(model - value), n, beta, factor

    trials = (
        trial(n, beta, factor)
        for n in np.linspace(0, max_n, steps)
        for beta in beta_grid
        for factor in scale_grid
    )
    _, best_n, best_beta, best_factor = min(trials, key=lambda t: t[0])
    return best_n, best_beta, best_factor

# Column order of the frame returned by parse_codata_ascii().
_CODATA_COLUMNS = ["name", "value", "uncertainty", "unit"]
# Compact layout: single-token value/uncertainty/unit, name followed by 2+ spaces.
_CODATA_COMPACT_ROW = re.compile(
    r"^\s*(.*?)\s{2,}([0-9Ee\+\-\.]+)\s+([0-9Ee\+\-\.]+|exact)\s+(\S+)"
)
# Columns in the NIST listing are separated by runs of two or more blanks.
_CODATA_COLUMN_GAP = re.compile(r"\s{2,}")

def _codata_float(text):
    """Convert a numeric field to float, dropping digit-group spaces and "..."."""
    return float(text.replace(" ", "").replace("...", ""))

def _parse_codata_row(line):
    """Return a record dict for one data line, or None if it is not a data row."""
    fields = _CODATA_COLUMN_GAP.split(line.strip())
    if len(fields) in (3, 4):
        name, value_str, uncert_str = fields[:3]
        try:
            value = _codata_float(value_str)
            is_exact = uncert_str.strip("()").lower() == "exact"
            uncertainty = None if is_exact else _codata_float(uncert_str)
        except ValueError:
            pass  # not the column layout; try the compact layout below
        else:
            return {
                "name": name,
                "value": value,
                "uncertainty": uncertainty,
                # Dimensionless constants have a blank unit column.
                "unit": fields[3] if len(fields) == 4 else "",
            }
    m = _CODATA_COMPACT_ROW.match(line)
    if m is None:
        return None
    name, value_str, uncert_str, unit = m.groups()
    try:
        return {
            "name": name.strip(),
            "value": float(value_str),
            "uncertainty": None if uncert_str == "exact" else float(uncert_str),
            "unit": unit.strip(),
        }
    except ValueError:
        return None

def parse_codata_ascii(filename):
    """Parse a CODATA ASCII listing into a DataFrame of constants.

    Two row layouts are understood:

    * the NIST ``allascii.txt`` layout, where columns are separated by two
      or more spaces, digits are grouped with single spaces
      (``6.644 657 3357 e-27``), truncated exact values end in ``...``,
      exact constants carry ``(exact)`` as uncertainty and the unit column
      may be blank or contain spaces (``m s^-1``);
    * a compact layout with single-token value, uncertainty (or ``exact``)
      and unit, which is tried when the first layout does not fit.

    Lines starting with ``Quantity`` or ``-``, blank lines, and any line
    that fits neither layout are skipped.

    Args:
        filename: Path of the text file to read.

    Returns:
        DataFrame with columns ``name``, ``value``, ``uncertainty`` (missing
        for exact constants) and ``unit``.  The columns are present even
        when no row could be parsed.
    """
    constants = []
    with open(filename, "r") as f:
        for line in f:
            if line.startswith("Quantity") or line.strip() == "" or line.startswith("-"):
                continue
            record = _parse_codata_row(line)
            if record is not None:
                constants.append(record)
    return pd.DataFrame(constants, columns=_CODATA_COLUMNS)

def fit_single_constant(row, r, k, Omega, base, scale, max_n, steps):
    """Fit one constant with invert_D() and report how close the model gets.

    Args:
        row: Mapping with ``name``, ``value``, ``unit`` and ``uncertainty``.
        r, k, Omega, base, scale: Model parameters forwarded to D()/invert_D().
        max_n, steps: Grid extent and resolution forwarded to invert_D().

    Returns:
        A dict with the constant's metadata, the fitted ``n``, ``beta`` and
        ``scale`` multiplier, the model value ``approx`` and the absolute
        ``error``.  ``None`` when the value is non-positive or above 1e50,
        or when the fit raises (the failure is printed, not propagated).
    """
    target = row['value']
    out_of_range = target <= 0 or target > 1e50
    if out_of_range:
        return None
    try:
        best_n, best_beta, best_factor = invert_D(
            target, r, k, Omega, base, scale, max_n, steps
        )
        model_value = D(best_n, best_beta, r, k, Omega, base, scale * best_factor)
        fit = {
            "name": row['name'],
            "value": target,
            "unit": row['unit'],
            "n": best_n,
            "beta": best_beta,
            "approx": model_value,
            "error": abs(target - model_value),
            "uncertainty": row['uncertainty'],
            "scale": best_factor,
        }
    except Exception as e:
        # Best effort: one bad constant must not abort the whole batch.
        print(f"Failed inversion for {row['name']}: {e}")
        return None
    return fit

def symbolic_fit_all_constants(df, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0, max_n=500, steps=500, n_jobs=20):
    """Fit every row of ``df`` with fit_single_constant(), in parallel.

    Args:
        df: DataFrame of constants; each row is passed to
            fit_single_constant().
        r, k, Omega, base, scale: Model parameters forwarded to each fit.
        max_n, steps: Grid extent and resolution forwarded to each fit.
        n_jobs: Worker count handed to joblib ``Parallel``.  Defaults to the
            previously hard-coded 20; pass -1 to use all available cores.

    Returns:
        DataFrame with one row per successful fit (rows whose fit returned
        ``None`` are dropped).  When no fit succeeds the frame is empty but
        still carries the result columns, so callers can select or sort by
        them without a KeyError.
    """
    fits = Parallel(n_jobs=n_jobs)(
        delayed(fit_single_constant)(row, r, k, Omega, base, scale, max_n, steps)
        for _, row in df.iterrows()
    )
    kept = [fit for fit in fits if fit is not None]
    if not kept:
        # Same keys, in the same order, as the dict built by fit_single_constant().
        return pd.DataFrame(columns=[
            "name", "value", "unit", "n", "beta",
            "approx", "error", "uncertainty", "scale",
        ])
    return pd.DataFrame(kept)

def total_error(params, df):
    """Objective for the optimiser: summed squared relative error of the fits.

    ``params`` is unpacked as ``(r, k, Omega, base, scale)``.  All constants
    in ``df`` are fitted on a 500-step grid up to n=500; fits whose absolute
    error lies above the 95th percentile are discarded before the squared
    relative errors ``((value - approx) / value)**2`` are summed.
    """
    r, k, Omega, base, scale = params
    fits = symbolic_fit_all_constants(
        df, r=r, k=k, Omega=Omega, base=base, scale=scale, max_n=500, steps=500
    )
    abs_err = fits['error']
    cutoff = np.percentile(abs_err, 95)
    # Drop the worst 5% (by absolute error) so outliers do not dominate.
    kept = fits.loc[abs_err <= cutoff]
    relative = (kept['value'] - kept['approx']) / kept['value']
    return (relative ** 2).sum()

if __name__ == "__main__":
    # Script entry point: parse the constants, tune the five model parameters
    # on a small subset, fit every constant, then report, save and plot.
    print("Parsing CODATA constants from allascii.txt...")
    codata_df = parse_codata_ascii("allascii.txt")
    print(f"Parsed {len(codata_df)} constants.")
    if codata_df.empty:
        # Nothing to fit: the optimiser and the plots below need at least one row.
        raise SystemExit("No constants could be parsed from allascii.txt; aborting.")

    # Use a subset for optimization
    subset_df = codata_df.head(20)
    # Order matches the unpacking in total_error(): r, k, Omega, base, scale.
    init_params = [1.0, 1.0, 1.0, 2.0, 1.0]
    bounds = [(1e-5, 10), (1e-5, 10), (1e-5, 10), (1.5, 10), (1e-5, 100)]

    print("Optimizing symbolic model parameters...")
    res = minimize(total_error, init_params, args=(subset_df,), bounds=bounds, method='L-BFGS-B', options={'maxiter': 100})
    if not res.success:
        # Carry on with the best point found, but make the failure visible.
        print(f"Warning: optimizer did not converge: {res.message}")
    r_opt, k_opt, Omega_opt, base_opt, scale_opt = res.x
    print(f"Optimization complete. Found parameters:\nr = {r_opt:.6f}, k = {k_opt:.6f}, Omega = {Omega_opt:.6f}, base = {base_opt:.6f}, scale = {scale_opt:.6f}")

    print("Fitting symbolic dimensions to all constants...")
    fitted_df = symbolic_fit_all_constants(codata_df, r=r_opt, k=k_opt, Omega=Omega_opt, base=base_opt, scale=scale_opt, max_n=500, steps=500)
    # Ascending absolute error: best fits first, worst fits last.
    fitted_df_sorted = fitted_df.sort_values("error")

    print("\nTop 20 best symbolic fits:")
    print(fitted_df_sorted.head(20).to_string(index=False))

    print("\nTop 20 worst symbolic fits:")
    print(fitted_df_sorted.tail(20).to_string(index=False))

    # Tab-separated despite the .txt extension.
    fitted_df_sorted.to_csv("symbolic_fit_results.txt", sep="\t", index=False)

    plt.figure(figsize=(10, 5))
    plt.hist(fitted_df_sorted['error'], bins=50, color='skyblue', edgecolor='black')
    plt.title('Histogram of Absolute Errors in Symbolic Fit')
    plt.xlabel('Absolute Error')
    plt.ylabel('Count')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.scatter(fitted_df_sorted['n'], fitted_df_sorted['error'], alpha=0.5, s=15, c='orange', edgecolors='black')
    plt.title('Absolute Error vs Symbolic Dimension n')
    plt.xlabel('n')
    plt.ylabel('Absolute Error')
    plt.grid(True)
    plt.tight_layout()
    plt.show()