import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from astropy.cosmology import FlatLambdaCDM

# Golden ratio φ = (1 + √5) / 2; every symbolic parameter below is a power of it.
phi = (np.sqrt(5) + 1) / 2

# Symbolic parameters from the compression results, each encoded as
#     value = prime * φ**power
# NOTE(review): 'alpha' uses a multiplier of 1, which is not a prime number;
# the 'prime' key is effectively an integer multiplier.
_SYMBOLIC_TABLE = (
    # name      power   prime
    ('k',       -1.340, 2),
    ('r0',      -1.340, 2),
    ('Omega0',  -1.340, 2),
    ('s0',      -1.452, 2),
    ('alpha',   -3.682, 1),
    ('beta',    -4.401, 3),
    ('gamma',   -2.296, 3),
)
symbolic_params = {
    name: {'power': power, 'prime': prime}
    for name, power, prime in _SYMBOLIC_TABLE
}

def reconstruct_param(power, prime):
    """Decode one symbolic parameter into its numeric value.

    Parameters
    ----------
    power : float
        Exponent applied to the module-level golden ratio ``phi``.
    prime : int or float
        Multiplier applied to ``phi**power``.

    Returns
    -------
    float
        ``prime * phi**power``.
    """
    phi_term = phi ** power
    return prime * phi_term

# Decode every symbolic parameter and print the result, so the run log records
# the exact numeric values the model uses.
params = {}
print("Symbolic parameter reconstruction:")
for name, vals in symbolic_params.items():
    power, prime = vals['power'], vals['prime']
    params[name] = val = reconstruct_param(power, prime)
    print(f"  {name:<7}: prime={prime:>2}, power={power:>7.3f} => {val:.6f}")

# Load the Pan-STARRS1 light-curve parameter table. names=True reads the column
# names from the header row, so columns are accessed by name (zcmb, mb, dmb).
filename = 'hlsp_ps1cosmo_panstarrs_gpc1_all_model_v1_lcparam-full.txt'
lc_data = np.genfromtxt(
    filename,
    # delimiter=None splits on any run of whitespace. A literal ' ' delimiter
    # turns every extra space between columns into an empty (missing) field,
    # which misaligns the values with the header names.
    delimiter=None,
    names=True,
    comments='#',
    dtype=None,
    encoding=None
)

z = lc_data['zcmb']
mb = lc_data['mb']
dmb = lc_data['dmb']
# Fixed absolute magnitude used to turn apparent magnitude mb into an observed
# distance modulus. NOTE(review): this is a hard-coded assumption, not fitted.
M = -19.3
mu_obs = mb - M

# Reference constants. They are held fixed for now but are candidates to become
# emergent quantities in future work.
c0 = 299792.458  # Speed of light (km/s)
H0 = 70.0        # Hubble constant (km/s/Mpc)

# Cosmological scaling functions driven by the symbolic emergent parameters.
def a_of_z(z):
    """Return the cosmological scale factor ``a = 1 / (1 + z)``."""
    one_plus_z = 1 + z
    return 1 / one_plus_z

def Omega(z, Omega0, alpha):
    """Return ``Omega0 / a(z)**alpha``, i.e. ``Omega0 * (1 + z)**alpha``.

    ``z`` may be a scalar or a numpy array; the operation is elementwise.
    """
    a = a_of_z(z)
    return Omega0 / a ** alpha

def s(z, s0, beta):
    """Return the scale function ``s0 * (1 + z)**(-beta)`` (elementwise in z)."""
    decay = (1 + z) ** (-beta)
    return s0 * decay

def G(z, k, r0, Omega0, s0, alpha, beta):
    """Return the emergent coupling ``Omega(z) * k**2 * r0 / s(z)``.

    This coupling multiplies the matter term in ``H``.
    """
    numerator = Omega(z, Omega0, alpha) * k**2 * r0
    return numerator / s(z, s0, beta)

def H(z, k, r0, Omega0, s0, alpha, beta, Om_m=0.3, Om_de=0.7):
    """Return the model Hubble rate H(z).

    Computes

        H(z)**2 = H0**2 * (Om_m * G(z) * (1 + z)**3 + Om_de)

    This is a LambdaCDM-like form with no curvature term, whose matter term is
    rescaled by the emergent coupling ``G(z)``.

    Parameters
    ----------
    z : float or ndarray
        Redshift(s).
    k, r0, Omega0, s0, alpha, beta : float
        Symbolic model parameters, forwarded unchanged to ``G``.
    Om_m, Om_de : float, optional
        Matter and dark-energy density fractions. The defaults (0.3, 0.7) are
        the values that were previously hard-coded here.

    Returns
    -------
    float or ndarray
        H(z) in the units of the module-level ``H0`` (km/s/Mpc). If the
        bracketed term is negative, ``np.sqrt`` yields nan.
    """
    Gz = G(z, k, r0, Omega0, s0, alpha, beta)
    Hz_sq = (H0 ** 2) * (Om_m * Gz * (1 + z) ** 3 + Om_de)
    return np.sqrt(Hz_sq)

def emergent_c(z, Omega0, alpha, gamma):
    """Return the emergent speed of light ``c0 * (Omega(z) / Omega0)**gamma``.

    At z = 0 the ratio is 1, so this reduces to the module-level ``c0``.
    """
    ratio = Omega(z, Omega0, alpha) / Omega0
    return c0 * ratio ** gamma

def compute_luminosity_distance_grid(z_max, params, n=500):
    """Build an interpolating luminosity-distance function on ``[0, z_max]``.

    The line-of-sight distance

        d_C(z) = integral_0^z c(z') / H(z') dz'

    is integrated with the cumulative trapezoid rule on an ``n``-point uniform
    grid and then cubic-interpolated. The returned function evaluates
    ``d_L(z) = (1 + z) * d_C(z)``.

    Parameters
    ----------
    z_max : float
        Upper end of the integration grid. Must be positive and finite.
    params : sequence of 7 floats
        ``(k, r0, Omega0, s0, alpha, beta, gamma)`` in that order.
    n : int, optional
        Number of grid points. Must be at least 4, because cubic
        interpolation needs that many.

    Returns
    -------
    callable
        ``d_L(z)`` for scalar, list or array ``z``. Values beyond ``z_max`` are
        extrapolated. Units follow ``c0 / H`` (Mpc with c0 in km/s and H0 in
        km/s/Mpc).

    Raises
    ------
    ValueError
        If ``z_max`` is not a positive finite number, or if ``n < 4``.
    """
    from scipy.integrate import cumulative_trapezoid

    # A zero/negative/NaN upper bound yields a degenerate grid that interp1d
    # rejects with an obscure message; fail early and clearly instead.
    if not (np.isfinite(z_max) and z_max > 0):
        raise ValueError(f"z_max must be a positive finite redshift, got {z_max!r}")
    if n < 4:
        raise ValueError(f"n must be >= 4 for cubic interpolation, got {n!r}")

    k, r0, Omega0, s0, alpha, beta, gamma = params
    z_grid = np.linspace(0, z_max, n)
    c_z = emergent_c(z_grid, Omega0, alpha, gamma)
    integrand_values = c_z / H(z_grid, k, r0, Omega0, s0, alpha, beta)
    # initial=0 prepends d_C(0) = 0, so the result has the same length as z_grid.
    integral_grid = cumulative_trapezoid(integrand_values, z_grid, initial=0)
    d_c = interp1d(z_grid, integral_grid, kind='cubic', fill_value="extrapolate")

    def d_L(z):
        """Luminosity distance ``(1 + z) * d_C(z)`` at redshift(s) ``z``."""
        # asarray lets callers pass plain lists as well as scalars/arrays.
        z = np.asarray(z, dtype=float)
        return (1 + z) * d_c(z)

    return d_L

def model_mu(z_arr, params):
    """Return the model distance modulus ``5 * log10(d_L) + 25`` at each redshift.

    The +25 offset matches the usual definition of the distance modulus when
    ``d_L`` is in Mpc, which is what ``compute_luminosity_distance_grid``
    yields for c0 in km/s and H0 in km/s/Mpc.

    Parameters
    ----------
    z_arr : scalar, sequence or ndarray
        Redshifts. Converted to a float array, so plain lists are accepted.
    params : sequence of 7 floats
        ``(k, r0, Omega0, s0, alpha, beta, gamma)``.

    Returns
    -------
    ndarray
        Distance moduli with the same shape as ``z_arr``. An empty input gives
        an empty output. A redshift with ``d_L <= 0`` (e.g. z = 0) gives
        -inf/nan from log10.
    """
    z_arr = np.asarray(z_arr, dtype=float)
    if z_arr.size == 0:
        # np.max would raise on an empty array, and there is nothing to evaluate.
        return np.empty_like(z_arr)
    # The grid only needs to reach the largest requested redshift.
    d_L_func = compute_luminosity_distance_grid(np.max(z_arr), params)
    d_L_vals = d_L_func(z_arr)
    return 5 * np.log10(d_L_vals) + 25

# Pack the parameters in the positional order that the model functions unpack:
# (k, r0, Omega0, s0, alpha, beta, gamma).
_PARAM_ORDER = ('k', 'r0', 'Omega0', 's0', 'alpha', 'beta', 'gamma')
param_list = [params[key] for key in _PARAM_ORDER]

# Evaluate the model at every supernova redshift, then print residual
# statistics (observed minus model distance modulus).
mu_fit = model_mu(z, param_list)
residuals = mu_obs - mu_fit

res_min, res_max = np.min(residuals), np.max(residuals)
res_rms = np.sqrt(np.mean(residuals**2))
print("\nFit evaluation with symbolic emergent parameters:")
print(f"Minimum residual: {res_min:.4f}")
print(f"Maximum residual: {res_max:.4f}")
print(f"Residual RMS    : {res_rms:.4f}")

# Plot fit vs data. The catalogue rows are not guaranteed to be ordered by
# redshift, and plt.plot joins points in array order. The model curve is
# therefore drawn on z-sorted copies, to avoid a zig-zag line.
z_order = np.argsort(z)
plt.figure(figsize=(10, 6))
plt.errorbar(z, mu_obs, yerr=dmb, fmt='.', alpha=0.5, label='Pan-STARRS1 SNe')
plt.plot(z[z_order], mu_fit[z_order], 'r-', label='Symbolic Emergent Gravity Model')
plt.xlabel('Redshift (z)')
plt.ylabel('Distance Modulus (μ)')
plt.title('Supernova Distance Modulus using Symbolic Parameters')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Residual plot: observed minus model distance modulus, with a dashed
# reference line at zero.
fig_res, ax_res = plt.subplots(figsize=(10, 4))
ax_res.errorbar(z, residuals, yerr=dmb, fmt='.', alpha=0.5)
ax_res.axhline(0, color='red', linestyle='--')
ax_res.set(
    xlabel='Redshift (z)',
    ylabel='Residuals (μ_data - μ_model)',
    title='Residuals of Symbolic Model',
)
ax_res.grid(True)
fig_res.tight_layout()
plt.show()
