import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import quad
from scipy.optimize import minimize
from scipy.interpolate import interp1d

# --- Supernova light-curve table -------------------------------------------
filename = 'hlsp_ps1cosmo_panstarrs_gpc1_all_model_v1_lcparam-full.txt'

# Space-separated text table: field names are taken from the header row
# (names=True), '#' introduces a comment, and dtype=None lets numpy choose a
# type for each column individually.
lc_data = np.genfromtxt(filename, delimiter=' ', names=True, comments='#',
                        dtype=None, encoding=None)

# Columns used by the fit. Presumably 'zcmb' is the CMB-frame redshift and
# 'mb' / 'dmb' a magnitude and its uncertainty -- confirm against the data
# release documentation.
z, mb, dmb = lc_data['zcmb'], lc_data['mb'], lc_data['dmb']

# Fixed offset removed from mb; the result is what the fit compares with the
# model's 5 * log10(d_L) + 25.
M = -19.3
mu_obs = mb - M

# --- Constants -------------------------------------------------------------
c = 299792.458   # speed of light km/s
H0 = 70.0        # km/s/Mpc

# Cosmological functions
def a_of_z(z):
    """Return a(z) = 1 / (1 + z).

    Works element-wise on anything that supports ``1 + z`` and true
    division (Python scalars, numpy arrays). Evaluates to 1 at z = 0.
    """
    one_plus_z = 1 + z
    return 1 / one_plus_z

def Omega(z, Omega0, alpha):
    """Return Omega0 / a(z)**alpha, with a(z) supplied by ``a_of_z``.

    Since a(0) = 1 the result equals Omega0 at z = 0; ``alpha`` sets how
    the value scales with a(z) away from z = 0.
    """
    scale = a_of_z(z)
    return Omega0 / scale ** alpha

def s(z, s0, beta):
    """Return s0 * (1 + z)**(-beta).

    The result equals s0 at z = 0; ``beta`` is the power-law index in
    (1 + z).
    """
    falloff = (1 + z) ** (-beta)
    return s0 * falloff

def G(z, k, r0, Omega0, s0, alpha, beta):
    """Return Omega(z) * k**2 * r0 / s(z).

    ``Omega0``/``alpha`` are forwarded to ``Omega`` and ``s0``/``beta`` to
    ``s``; ``k`` and ``r0`` enter only as the constant factor k**2 * r0.
    """
    numerator = Omega(z, Omega0, alpha) * k**2 * r0
    denominator = s(z, s0, beta)
    return numerator / denominator

def H(z, k, r0, Omega0, s0, alpha, beta, Om_m=0.3, Om_de=0.7):
    """Return the model expansion rate H(z).

    Computes ``H0 * sqrt(Om_m * G(z) * (1 + z)**3 + Om_de)`` using the
    module-level ``H0``, so the result carries the same units as ``H0``.

    Parameters
    ----------
    z : float or array-like
        Redshift(s) at which to evaluate.
    k, r0, Omega0, s0, alpha, beta : float
        Model parameters forwarded unchanged to ``G``.
    Om_m : float, optional
        Weight of the ``G(z) * (1 + z)**3`` term. Defaults to 0.3, the
        value that used to be hard-coded here.
    Om_de : float, optional
        Constant term inside the square root. Defaults to 0.7, the value
        that used to be hard-coded here.

    Returns
    -------
    float or numpy.ndarray
        H(z), with the same shape as ``z``. Entries are NaN wherever the
        expression under the square root is negative.
    """
    Gz = G(z, k, r0, Omega0, s0, alpha, beta)
    Hz_sq = (H0 ** 2) * (Om_m * Gz * (1 + z) ** 3 + Om_de)
    return np.sqrt(Hz_sq)

def emergent_c(z, Omega0, alpha, gamma=1.0):
    """Return c * (Omega(z) / Omega0)**gamma.

    ``c`` is the module-level speed of light, so the result equals ``c``
    at z = 0. ``gamma`` is the exponent applied to the Omega ratio.
    """
    ratio = Omega(z, Omega0, alpha) / Omega0
    return c * ratio ** gamma

def compute_luminosity_distance_grid(z_max, params, gamma=1.0, n=500):
    """Build a callable d_L(z) from a tabulated distance integral.

    The integrand ``emergent_c(z) / H(z)`` is sampled on ``n`` evenly spaced
    redshifts in [0, z_max], integrated cumulatively with the trapezoid
    rule, and interpolated with a cubic spline (extrapolating outside the
    grid). The returned function multiplies that integral by (1 + z).

    Parameters
    ----------
    z_max : float
        Upper end of the redshift grid.
    params : sequence of 6 floats
        ``(k, r0, Omega0, s0, alpha, beta)``, in that order.
    gamma : float, optional
        Exponent forwarded to ``emergent_c``.
    n : int, optional
        Number of grid points.

    Returns
    -------
    callable
        ``d_L(z)`` accepting a scalar or array of redshifts.
    """
    k, r0, Omega0, s0, alpha, beta = params
    z_grid = np.linspace(0, z_max, n)

    # Integrand c(z) / H(z) evaluated on the grid.
    speed = emergent_c(z_grid, Omega0, alpha, gamma)
    expansion = H(z_grid, k, r0, Omega0, s0, alpha, beta)
    integrand = speed / expansion

    # Cumulative trapezoid rule: mean of neighbouring samples times the grid
    # step, with a leading 0 so the integral vanishes at z = 0.
    panel_means = (integrand[:-1] + integrand[1:]) / 2
    running = np.cumsum(panel_means * np.diff(z_grid))
    running = np.insert(running, 0, 0)

    d_c = interp1d(z_grid, running, kind='cubic', fill_value="extrapolate")

    def d_L(z):
        return (1 + z) * d_c(z)

    return d_L


def model_mu(z_arr, params):
    """Return the model value 5 * log10(d_L) + 25 at each redshift.

    The distance table is built up to the largest redshift in ``z_arr`` so
    that every requested point lies inside the interpolation grid.
    ``params`` is forwarded to ``compute_luminosity_distance_grid``.
    """
    z_top = np.max(z_arr)
    lum_dist = compute_luminosity_distance_grid(z_top, params)(z_arr)
    return 5 * np.log10(lum_dist) + 25

def chi_squared(params, z_arr, mu_obs, mu_err):
    """Return the sum of squared, error-weighted residuals.

    Each residual is ``(mu_obs - model_mu(z_arr, params)) / mu_err``.
    ``params`` comes first so the function can be passed directly to
    ``scipy.optimize.minimize`` with the data supplied through ``args``.
    """
    pulls = (mu_obs - model_mu(z_arr, params)) / mu_err
    return np.sum(pulls ** 2)

# Initial guess, in the order (k, r0, Omega0, s0, alpha, beta) that
# compute_luminosity_distance_grid unpacks.
p0 = [1.0, 1.0, 1.0, 1.0, 3.0, 1.0]

# Expanded bounds: the same (0.01, 10) box for every parameter.
bounds = [(0.01, 10)] * len(p0)

# NOTE: k, r0, Omega0 and s0 reach the model only through G, as the single
# product Omega0 * k**2 * r0 / s0 (the Omega0 in emergent_c cancels in the
# ratio Omega(z) / Omega0). The chi2 surface is therefore flat along those
# directions: their individual best-fit values depend on the starting point
# and should not be interpreted separately.

# Run optimizer
result = minimize(chi_squared, p0, args=(z, mu_obs, dmb), bounds=bounds, method='L-BFGS-B')

print("Fit success:", result.success)
print("Best-fit parameters:", result.x)
print("Minimum chi2:", result.fun)

# Calculate best-fit model and residuals
mu_fit = model_mu(z, result.x)
residuals = mu_obs - mu_fit

# A line plot joins points in the order they are given, and the catalogue
# rows are not guaranteed to be sorted by redshift. Order by z so the model
# is drawn as one curve rather than a zig-zag between points.
order = np.argsort(z)

# Plot results
plt.figure(figsize=(10, 6))
plt.errorbar(z, mu_obs, yerr=dmb, fmt='.', alpha=0.5, label='Pan-STARRS1 SNe')
plt.plot(z[order], mu_fit[order], 'r-', label='Best-fit Emergent Gravity Model')
plt.xlabel('Redshift (z)')
plt.ylabel('Distance Modulus (μ)')
plt.title('Supernova Data and Model Fit')
plt.legend()
plt.grid(True)
plt.show()

# Residuals plot (markers only, so point order does not matter here)
plt.figure(figsize=(10, 4))
plt.errorbar(z, residuals, yerr=dmb, fmt='.', alpha=0.5)
plt.axhline(0, color='red', linestyle='--')
plt.xlabel('Redshift (z)')
plt.ylabel('Residuals (μ_data - μ_model)')
plt.title('Residuals of the Fit')
plt.grid(True)
plt.show()
