import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from tqdm import tqdm
from joblib import Parallel, delayed

# Extended primes list (up to 1000): every integer in [2, 1000) that has no
# divisor between 2 and its integer square root, in ascending order
# (168 primes, from 2 through 997).
PRIMES = [
    candidate
    for candidate in range(2, 1000)
    if all(candidate % divisor for divisor in range(2, int(candidate ** 0.5) + 1))
]

# Golden ratio; used by fib_real and as a multiplicative factor in D().
phi = (1 + np.sqrt(5)) / 2
# Memo table for fib_real, keyed by the exact argument value. It is not
# bounded here; symbolic_fit_all_constants clears it once it exceeds 10000
# entries.
fib_cache = {}


def fib_real(n):
    """Return the real-valued (continuous) Fibonacci number F(n).

    Uses the generalisation of Binet's formula to real arguments::

        F(n) = (phi**n - cos(pi * n) * phi**(-n)) / sqrt(5)

    which reproduces the integer Fibonacci sequence at integer ``n``
    (F(0) = 0, F(1) = 1, F(10) = 55) and interpolates smoothly between them.

    Args:
        n: Real-valued index (scalar).

    Returns:
        F(n) as a NumPy float, or ``0.0`` when ``n > 100``. Results for
        ``n <= 100`` are memoised in ``fib_cache``; cut-off results are not.
        NOTE(review): the intent of the ``n > 100`` cutoff is not documented;
        it makes D() collapse to its 1e-30 floor for such n.
    """
    if n in fib_cache:
        return fib_cache[n]
    if n > 100:
        return 0.0
    # Both Binet terms carry the 1/sqrt(5) factor; dividing only the first
    # one would give F(1) ~= 1.34 and F(0) ~= -0.55.
    result = (phi**n - np.cos(np.pi * n) * phi ** (-n)) / np.sqrt(5)
    fib_cache[n] = result
    return result

def D(n, beta, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0):
    """Evaluate the symbolic model at the real index ``x = n + beta``.

        D = sqrt(max(scale * phi * F(x) * base**x * P(x) * Omega, 1e-30)) * r**k

    where F is ``fib_real`` and P is ``PRIMES[floor(x) mod len(PRIMES)]``.

    Args:
        n, beta: Index and fractional offset; only their sum ``x`` is used.
            Scalars are expected (``int(...)`` is applied to ``floor(x)``).
        r, k, Omega, base, scale: Model parameters.

    Returns:
        The model value as a NumPy float. A NaN product (see below) is
        treated like any other value under the floor and yields the
        1e-30 floor rather than NaN.
    """
    x = n + beta
    fn_beta = fib_real(x)
    # Wrap floor(x) cyclically onto the prime table.
    prime = PRIMES[int(np.floor(x) + len(PRIMES)) % len(PRIMES)]
    # For large x, base**x overflows to inf while fib_real returns 0.0
    # (x > 100), so the product is 0 * inf = NaN; those warnings are expected.
    # np.power also avoids the OverflowError that float ** float raises.
    with np.errstate(over="ignore", invalid="ignore"):
        dyadic = np.power(float(base), x)
        val = scale * phi * fn_beta * dyadic * prime * Omega
    # np.fmax (unlike np.maximum) returns the non-NaN operand, so the NaN case
    # collapses to the 1e-30 floor instead of propagating into invert_D.
    val = np.fmax(val, 1e-30)
    return np.sqrt(val) * (r ** k)

def invert_D(value, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0, max_n=500, steps=500):
    """Find the grid point (n, beta, dynamic_scale) whose D() is closest to ``value``.

    Performs an exhaustive search over:
      * n             -- ``steps`` points on [0, max_n], log-spaced on
                         [1, max_n] instead when |value| > 1e3;
      * beta          -- 10 evenly spaced points on [0, 1];
      * dynamic_scale -- 20 log-spaced factors spanning
                         10**(log10|value| - 4) .. 10**(log10|value| + 4).

    Args:
        value: Target value to approximate.
        r, k, Omega, base: Model parameters forwarded to D().
        scale: Base scale; multiplied by each dynamic_scale candidate.
        max_n, steps: NOTE(review): currently ignored -- both are recomputed
            from log10|value| below, so the values passed by callers have no
            effect. Kept for interface compatibility.

    Returns:
        Tuple ``(n, beta, dynamic_scale)`` with the smallest finite
        ``|D(...) - value|``; ties resolve to the first point in search order.

    Raises:
        ValueError: If ``|D(...) - value|`` is non-finite at every grid point.
    """
    log_val = np.log10(max(abs(value), 1e-30))
    scale_factors = np.logspace(log_val - 4, log_val + 4, num=20)
    max_n = min(5000, max(100, int(200 * log_val)))
    steps = min(3000, max(500, int(200 * log_val)))
    if log_val > 3:
        n_values = np.logspace(0, np.log10(max_n), steps)
    else:
        n_values = np.linspace(0, max_n, steps)
    betas = np.linspace(0, 1, 10)  # loop-invariant, built once

    best = None
    best_diff = np.inf
    for n in n_values:
        for beta in betas:
            for dynamic_scale in scale_factors:
                diff = abs(D(n, beta, r, k, Omega, base, scale * dynamic_scale) - value)
                # Running minimum instead of storing and sorting every
                # candidate. NaN compares False here and is skipped; as a
                # sort key it would silently break the ordering.
                if diff < best_diff:
                    best_diff = diff
                    best = (n, beta, dynamic_scale)
    if best is None:
        raise ValueError(f"D() is non-finite over the whole search grid for value={value!r}")
    return best

# Column separator for the NIST allascii.txt layout: columns are padded with
# runs of spaces, while digit groups inside a number ("1.672 621 923 69 e-27")
# and multi-word units ("m s^-1") use single spaces.
_CODATA_FIELD_SEP = re.compile(r"\s{2,}")
# Fallback for compact lines such as "name  1.67e-27 5.1e-37 kg".
_CODATA_COMPACT_LINE = re.compile(
    r"^\s*(.*?)\s{2,}([0-9Ee\+\-\.]+)\s+([0-9Ee\+\-\.]+|exact)\s+(\S+)"
)


def _parse_codata_number(text):
    """Convert a CODATA numeric field (e.g. '1.986 445 857... e-25') to float.

    Raises ValueError if the text is not numeric.
    """
    # Drop digit-group spaces and the "..." that marks a truncated value.
    return float(text.replace(" ", "").replace("...", ""))


def _parse_codata_line(line):
    """Parse one listing line into (name, value, uncertainty, unit), or None."""
    fields = _CODATA_FIELD_SEP.split(line.strip())
    # 4 columns normally; 3 when the unit column is empty (dimensionless).
    if len(fields) in (3, 4):
        name, value_str, uncert_str = fields[:3]
        unit = fields[3] if len(fields) == 4 else ""
        try:
            value = _parse_codata_number(value_str)
            uncertainty = (
                None if uncert_str.strip("()") == "exact" else _parse_codata_number(uncert_str)
            )
            return name.strip(), value, uncertainty, unit.strip()
        except ValueError:
            pass  # not the column layout; try the compact regex below
    m = _CODATA_COMPACT_LINE.match(line)
    if m is None:
        return None
    name, value_str, uncert_str, unit = m.groups()
    try:
        value = float(value_str)
        uncertainty = None if uncert_str == "exact" else float(uncert_str)
    except ValueError:
        return None
    return name.strip(), value, uncertainty, unit.strip()


def parse_codata_ascii(filename):
    """Parse a CODATA constants listing into a DataFrame.

    Handles the NIST ``allascii.txt`` layout: columns separated by runs of
    two or more spaces, single spaces grouping digits inside numbers, "..."
    marking truncated values and "(exact)" marking exact constants. Lines
    that do not fit that layout fall back to a compact
    "name  value uncertainty unit" pattern; anything else (titles, the
    "Quantity" header, dashed rules, blank lines) is skipped.

    Args:
        filename: Path to the listing.

    Returns:
        DataFrame with columns ``name``, ``value``, ``uncertainty`` (missing,
        i.e. None/NaN, for exact constants) and ``unit`` ("" when the listing
        has no unit). The columns are present even when nothing was parsed.
    """
    constants = []
    with open(filename, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            stripped = line.strip()
            if not stripped or stripped.startswith(("Quantity", "-")):
                continue
            parsed = _parse_codata_line(line)
            if parsed is None:
                continue
            name, value, uncertainty, unit = parsed
            constants.append({
                "name": name,
                "value": value,
                "uncertainty": uncertainty,
                "unit": unit,
            })
    return pd.DataFrame(constants, columns=["name", "value", "uncertainty", "unit"])

def _lookup_constant(df, name, column='value'):
    """Return ``column`` of the first row whose 'name' equals ``name``, or None."""
    matches = df.loc[df['name'] == name, column]
    return None if matches.empty else matches.iloc[0]


def _all_present(*values):
    """True when none of ``values`` is None or NaN."""
    return all(pd.notna(v) for v in values)


def check_physical_consistency(df, rel_tol=1e-9):
    """Cross-check constants that are related by known identities.

    Checks performed (each only when all involved constants are present):
      * proton mass / electron mass vs. 'proton-electron mass ratio',
        flagged when the difference exceeds 5x the ratio's uncertainty;
      * 'speed of light in vacuum' vs. 'inverse meter-hertz relationship';
      * 'Planck constant' / (2*pi) vs. 'reduced Planck constant'.

    The last two use a *relative* tolerance: an absolute tolerance does not
    scale with magnitude (h is ~6.6e-34, so an absolute 1e-10 bound could
    never be exceeded).

    Args:
        df: DataFrame with at least 'name', 'value' and 'uncertainty' columns.
        rel_tol: Relative tolerance for the two identity checks. The default
            leaves room for values printed to ~10 significant digits.

    Returns:
        List of dicts with keys 'name', 'value' and 'reason', one per
        inconsistent constant (empty when everything agrees).
    """
    bad_data = []

    # Mass ratio consistency (e.g., proton-electron mass ratio)
    proton_mass = _lookup_constant(df, 'proton mass')
    electron_mass = _lookup_constant(df, 'electron mass')
    proton_electron_ratio = _lookup_constant(df, 'proton-electron mass ratio')
    if _all_present(proton_mass, electron_mass, proton_electron_ratio) and electron_mass != 0:
        calc_ratio = proton_mass / electron_mass
        diff = abs(calc_ratio - proton_electron_ratio)
        uncert = _lookup_constant(df, 'proton-electron mass ratio', 'uncertainty')
        # Exact constants carry no uncertainty, so there is nothing to compare against.
        if _all_present(uncert) and diff > 5 * uncert:
            bad_data.append({
                'name': 'proton-electron mass ratio',
                'value': proton_electron_ratio,
                'reason': f'Inconsistent with proton mass / electron mass (diff: {diff:.2e} > 5 * {uncert:.2e})'
            })

    # Speed of light vs. inverse meter-hertz relationship
    c = _lookup_constant(df, 'speed of light in vacuum')
    inv_m_hz = _lookup_constant(df, 'inverse meter-hertz relationship')
    if _all_present(c, inv_m_hz) and not np.isclose(inv_m_hz, c, rtol=rel_tol, atol=0.0):
        bad_data.append({
            'name': 'inverse meter-hertz relationship',
            'value': inv_m_hz,
            'reason': f'Inconsistent with speed of light ({c:.2e} vs. {inv_m_hz:.2e})'
        })

    # Planck constant vs. reduced Planck constant
    h = _lookup_constant(df, 'Planck constant')
    h_bar = _lookup_constant(df, 'reduced Planck constant')
    if _all_present(h, h_bar) and not np.isclose(h_bar, h / (2 * np.pi), rtol=rel_tol, atol=0.0):
        bad_data.append({
            'name': 'reduced Planck constant',
            'value': h_bar,
            'reason': f'Inconsistent with Planck constant / (2π) ({h/(2*np.pi):.2e} vs. {h_bar:.2e})'
        })

    return bad_data

def fit_single_constant(row, r, k, Omega, base, scale, max_n, steps, error_threshold):
    """Fit one constant with the symbolic model and apply per-row data checks.

    Args:
        row: Mapping with 'name', 'value', 'uncertainty' and 'unit' entries
            (e.g. a DataFrame row); 'uncertainty' is None/NaN for exact values.
        r, k, Omega, base, scale, max_n, steps: Forwarded to invert_D / D.
        error_threshold: Absolute error above which a fit with a very small
            quoted uncertainty (< 1e-5 relative) is flagged.

    Returns:
        Dict describing the fit (value, n, beta, approx, errors, scale and
        bad-data flags), or None when the value is outside the fittable range
        (non-finite, <= 0 or > 1e50) or the inversion fails.
    """
    val = row['value']
    # D() is always positive, and non-finite values cannot be inverted.
    if not np.isfinite(val) or val <= 0 or val > 1e50:
        return None
    try:
        n, beta, dynamic_scale = invert_D(val, r, k, Omega, base, scale, max_n, steps)
        approx = D(n, beta, r, k, Omega, base, scale * dynamic_scale)
    except Exception as e:
        # Best effort: report and skip this constant rather than abort the batch.
        print(f"Failed inversion for {row['name']}: {e}")
        return None

    error = abs(val - approx)
    rel_error = error / max(abs(val), 1e-30)

    uncertainty = row['uncertainty']
    has_uncertainty = pd.notna(uncertainty)  # None or NaN => no quoted uncertainty
    bad_data_reason = []
    # Uncertainty check
    if has_uncertainty and (uncertainty < 1e-10 or uncertainty > 0.1 * abs(val)):
        bad_data_reason.append("Suspicious uncertainty")
    # Outlier check
    if error > error_threshold and has_uncertainty and uncertainty < 1e-5 * abs(val):
        bad_data_reason.append("High error with low uncertainty")

    return {
        "name": row['name'],
        "value": val,
        "unit": row['unit'],
        "n": n,
        "beta": beta,
        "approx": approx,
        "error": error,
        "rel_error": rel_error,
        "uncertainty": uncertainty,
        "scale": dynamic_scale,
        "bad_data": bool(bad_data_reason),
        "bad_data_reason": "; ".join(bad_data_reason),
    }

def _mark_bad_data(df_results, mask, reason):
    """Set 'bad_data' and append ``reason`` to 'bad_data_reason' for rows in ``mask``."""
    df_results.loc[mask, 'bad_data'] = True
    df_results.loc[mask, 'bad_data_reason'] = (
        df_results.loc[mask, 'bad_data_reason'] + "; " + reason
    ).str.strip("; ")


def _flag_high_error(df_results, percentile=99):
    """Flag fits whose error exceeds the given percentile while uncertainty is < 1e-5 relative."""
    error_threshold = np.percentile(df_results['error'], percentile)
    uncert = pd.to_numeric(df_results['uncertainty'], errors='coerce')  # None -> NaN
    high_error = (
        (df_results['error'] > error_threshold)
        & uncert.notna()
        & (uncert < 1e-5 * df_results['value'].abs())
    )
    _mark_bad_data(df_results, high_error, "High error with low uncertainty")


def _flag_physical_inconsistencies(df_results, df):
    """Flag rows named by check_physical_consistency(df)."""
    for bad in check_physical_consistency(df):
        _mark_bad_data(df_results, df_results['name'] == bad['name'], bad['reason'])


def _flag_uncertainty_outliers(df_results):
    """Flag uncertainties above median + 3*std within quintiles of log10|value|."""
    log_values = np.log10(df_results['value'].abs().clip(1e-30))
    bins = pd.qcut(log_values, 5, duplicates='drop')
    for interval in bins.unique():
        in_bin = bins == interval
        bin_uncert = df_results.loc[in_bin, 'uncertainty']
        if not bin_uncert.notnull().any():
            continue
        std_uncert = bin_uncert.std()
        if np.isnan(std_uncert):  # e.g. a single quoted uncertainty in the bin
            continue
        outliers = in_bin & (df_results['uncertainty'] > bin_uncert.median() + 3 * std_uncert)
        _mark_bad_data(df_results, outliers, "Uncertainty outlier")


def symbolic_fit_all_constants(df, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0, max_n=500, steps=500, n_jobs=20):
    """Fit every constant in ``df`` with the symbolic model and flag suspicious data.

    Each row is fitted independently with fit_single_constant (in parallel
    via joblib). Three dataset-level flags are then layered on top of the
    per-row checks:
      1. error above the 99th percentile of all fit errors while the quoted
         uncertainty is below 1e-5 of the value;
      2. inconsistencies reported by check_physical_consistency(df);
      3. uncertainties more than 3 standard deviations above the median
         within quintiles of log10|value|.

    Args:
        df: Constants table with 'name', 'value', 'uncertainty', 'unit'.
        r, k, Omega, base, scale, max_n, steps: Forwarded to fit_single_constant.
        n_jobs: Number of joblib workers.

    Returns:
        DataFrame with one row per constant that could be fitted (empty when
        none could).
    """
    # One fitting pass is enough: the error threshold only drives flag (1),
    # which is applied below once all errors are known. Passing np.inf here
    # disables that per-row check inside fit_single_constant.
    fits = Parallel(n_jobs=n_jobs)(
        delayed(fit_single_constant)(row, r, k, Omega, base, scale, max_n, steps, np.inf)
        for _, row in df.iterrows()
    )
    df_results = pd.DataFrame([fit for fit in fits if fit is not None])

    if not df_results.empty:
        _flag_high_error(df_results)
        _flag_physical_inconsistencies(df_results, df)
        _flag_uncertainty_outliers(df_results)

    # NOTE(review): joblib's default backend runs tasks in worker processes,
    # which hold their own fib_cache copies; this only trims the cache of the
    # current process.
    if len(fib_cache) > 10000:
        fib_cache.clear()
    return df_results

def total_error(params, df):
    """Objective for scipy.optimize.minimize over the model parameters.

    Unpacks ``params`` as (r, k, Omega, base, scale), refits every constant
    in ``df``, discards fits whose absolute error is above the 95th
    percentile, and returns the sum of squared relative errors
    ((value - approx) / value)**2 over the remaining fits.
    """
    r, k, Omega, base, scale = params
    fits = symbolic_fit_all_constants(
        df, r=r, k=k, Omega=Omega, base=base, scale=scale, max_n=500, steps=500
    )
    cutoff = np.percentile(fits['error'], 95)
    kept = fits.loc[fits['error'] <= cutoff]
    relative = (kept['value'] - kept['approx']) / kept['value']
    return (relative ** 2).sum()

def main():
    """Parse allascii.txt, optimise the model on a subset, fit all constants, report and plot."""
    report_columns = ['name', 'value', 'unit', 'n', 'beta', 'approx', 'error',
                      'uncertainty', 'scale', 'bad_data', 'bad_data_reason']

    print("Parsing CODATA constants from allascii.txt...")
    codata_df = parse_codata_ascii("allascii.txt")
    print(f"Parsed {len(codata_df)} constants.")

    # Use a larger subset for optimization
    subset_df = codata_df.head(50)
    init_params = [1.0, 1.0, 1.0, 2.0, 1.0]  # r, k, Omega, base, scale
    bounds = [(1e-5, 10), (1e-5, 10), (1e-5, 10), (1.5, 10), (1e-5, 100)]

    print("Optimizing symbolic model parameters...")
    res = minimize(total_error, init_params, args=(subset_df,), bounds=bounds, method='L-BFGS-B', options={'maxiter': 100})
    if not res.success:
        # minimize still returns its last iterate; make the failure visible.
        print(f"Warning: optimizer did not report success: {res.message}")
    r_opt, k_opt, Omega_opt, base_opt, scale_opt = res.x
    print(f"Optimization complete. Found parameters:\nr = {r_opt:.6f}, k = {k_opt:.6f}, Omega = {Omega_opt:.6f}, base = {base_opt:.6f}, scale = {scale_opt:.6f}")

    print("Fitting symbolic dimensions to all constants...")
    fitted_df = symbolic_fit_all_constants(codata_df, r=r_opt, k=k_opt, Omega=Omega_opt, base=base_opt, scale=scale_opt, max_n=500, steps=500)
    fitted_df_sorted = fitted_df.sort_values("error")

    print("\nTop 20 best symbolic fits:")
    print(fitted_df_sorted.head(20)[report_columns].to_string(index=False))

    print("\nTop 20 worst symbolic fits:")
    print(fitted_df_sorted.tail(20)[report_columns].to_string(index=False))

    print("\nPotentially bad data constants summary:")
    bad_data_df = fitted_df[fitted_df['bad_data'].astype(bool)][['name', 'value', 'error', 'rel_error', 'uncertainty', 'bad_data_reason']]
    print(bad_data_df.to_string(index=False))

    fitted_df_sorted.to_csv("symbolic_fit_results.txt", sep="\t", index=False)

    plt.figure(figsize=(10, 5))
    plt.hist(fitted_df_sorted['error'], bins=50, color='skyblue', edgecolor='black')
    plt.title('Histogram of Absolute Errors in Symbolic Fit')
    plt.xlabel('Absolute Error')
    plt.ylabel('Count')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.scatter(fitted_df_sorted['n'], fitted_df_sorted['error'], alpha=0.5, s=15, c='orange', edgecolors='black')
    plt.title('Absolute Error vs Symbolic Dimension n')
    plt.xlabel('n')
    plt.ylabel('Absolute Error')
    plt.grid(True)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()