import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import quad
from scipy.optimize import minimize

# Load the Pan-STARRS1 supernova light-curve parameter table.
# With names=True, numpy.genfromtxt takes field names from the first line and
# tolerates a leading comment character on it (if the header starts with '#').
filename = 'hlsp_ps1cosmo_panstarrs_gpc1_all_model_v1_lcparam-full.txt'
lc_data = np.genfromtxt(
    filename,
    # None = split on any run of whitespace. A literal ' ' delimiter turns
    # doubled or trailing spaces into empty fields and a column-count error.
    delimiter=None,
    names=True,
    comments='#',
    dtype=None,
    encoding=None
)

z = lc_data['zcmb']    # redshift column (presumably CMB-frame, per its name)
mb = lc_data['mb']     # apparent magnitude column
dmb = lc_data['dmb']   # uncertainty on mb; used below as the error on mu_obs
M = -19.3              # assumed fixed absolute magnitude (not fitted)
mu_obs = mb - M        # observed distance modulus

# Constants
H0 = 70.0        # Hubble constant km/s/Mpc
c = 299792.458   # speed of light km/s

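# Model functions below take the parameter vector
# p = [k, r0, Omega0, s0, alpha, beta].
# Note: the parameters enter H(z) only through
# G(z) = (Omega0 * k**2 * r0 / s0) * (1 + z)**(alpha + beta),
# so the data can constrain only those two combinations, not the six individual values.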

def a_of_z(z):
    """Return the scale factor a = 1 / (1 + z) for redshift ``z`` (scalar or array)."""
    one_plus_z = 1 + z
    return 1 / one_plus_z

def Omega(z, Omega0, alpha):
    """Return Omega0 / a(z)**alpha, with scale factor a(z) = 1 / (1 + z).

    Mathematically this is Omega0 * (1 + z)**alpha, so it equals Omega0 at z = 0.
    """
    scale_factor = 1 / (1 + z)
    return Omega0 / scale_factor ** alpha

def s(z, s0, beta):
    """Return s(z) = s0 * (1 + z)**(-beta); equals s0 at z = 0."""
    decay = (1 + z) ** (-beta)
    return s0 * decay

def G(z, k, r0, Omega0, s0, alpha, beta):
    """Return G(z) = Omega(z) * k**2 * r0 / s(z).

    Substituting Omega(z) = Omega0 * (1 + z)**alpha and s(z) = s0 * (1 + z)**(-beta)
    gives G(z) = (Omega0 * k**2 * r0 / s0) * (1 + z)**(alpha + beta). The function
    therefore depends on only two combinations of its six parameters.
    """
    numerator = Omega(z, Omega0, alpha) * k**2 * r0
    return numerator / s(z, s0, beta)

def H(z, k, r0, Omega0, s0, alpha, beta):
    """Return the model Hubble rate H(z), in the units of the module-level H0 (km/s/Mpc).

    H(z)**2 = H0**2 * (0.3 * G(z) * (1 + z)**3 + 0.7), with fixed density
    fractions 0.3 (matter) and 0.7 (dark energy). Because G(0) is not forced to 1,
    H(0) = H0 * sqrt(0.3 * G(0) + 0.7) generally differs from H0.
    """
    matter_fraction = 0.3
    dark_energy_fraction = 0.7
    coupling = G(z, k, r0, Omega0, s0, alpha, beta)
    matter_term = matter_fraction * coupling * (1 + z) ** 3
    return np.sqrt(H0 ** 2 * (matter_term + dark_energy_fraction))

def integrand(z_prime, k, r0, Omega0, s0, alpha, beta):
    """Return c / H(z'), the comoving-distance integrand (Mpc per unit redshift).

    The units follow from c in km/s and H in km/s/Mpc.
    """
    hubble_rate = H(z_prime, k, r0, Omega0, s0, alpha, beta)
    return c / hubble_rate

def luminosity_distance(z_val, k, r0, Omega0, s0, alpha, beta):
    """Return d_L(z) = (1 + z) * integral_0^z c / H(z') dz' in Mpc.

    This is the spatially flat relation (no curvature term). The integral is
    evaluated with scipy.integrate.quad, and its error estimate is discarded.
    """
    model_args = (k, r0, Omega0, s0, alpha, beta)
    comoving_distance = quad(integrand, 0, z_val, args=model_args)[0]
    return (1 + z_val) * comoving_distance

def distance_modulus(z_val, k, r0, Omega0, s0, alpha, beta):
    """Return the distance modulus mu = 5 * log10(d_L / Mpc) + 25 at redshift z_val.

    At z_val == 0, d_L is 0 and the result is -inf (numpy emits a divide warning).
    """
    distance_mpc = luminosity_distance(z_val, k, r0, Omega0, s0, alpha, beta)
    return 25 + 5 * np.log10(distance_mpc)

def model_mu(z_arr, params):
    """Return a NumPy array of model distance moduli, one per redshift in z_arr.

    ``params`` must unpack to exactly [k, r0, Omega0, s0, alpha, beta]. Each
    redshift is integrated independently, so the cost is one quad call per element.
    """
    k, r0, Omega0, s0, alpha, beta = params
    model_args = (k, r0, Omega0, s0, alpha, beta)
    return np.array([distance_modulus(redshift, *model_args) for redshift in z_arr])

# Chi-squared cost: sum of squared, error-normalised residuals.
def chi_squared(params, z_arr, mu_obs, mu_err):
    """Return sum(((mu_obs - model) / mu_err)**2) for the given parameter vector.

    Any mu_err entry equal to 0 produces an inf/nan contribution.
    """
    normalised_residuals = (mu_obs - model_mu(z_arr, params)) / mu_err
    return np.sum(normalised_residuals ** 2)

# Initial guess for parameters, ordered [k, r0, Omega0, s0, alpha, beta]
p0 = [1.0, 1.0, 1.0, 1.0, 3.0, 1.0]

# Bounds keep every parameter strictly positive, which keeps G(z) > 0 and H(z)**2 > 0
bounds = [(0.1, 10),   # k
          (0.1, 10),   # r0
          (0.1, 10),   # Omega0
          (0.1, 10),   # s0
          (0.1, 10),   # alpha
          (0.1, 10)]   # beta

# Run the minimizer. No jac is supplied, so L-BFGS-B estimates gradients by
# finite differences (several full chi-squared evaluations per step).
result = minimize(chi_squared, p0, args=(z, mu_obs, dmb), bounds=bounds, method='L-BFGS-B')

print("Fit success:", result.success)
if not result.success:
    print("Optimizer message:", result.message)
print("Best-fit parameters:", result.x)
print("Minimum chi2:", result.fun)

# The model depends on the six parameters only through these two combinations,
# so the individual best-fit values above are not separately constrained.
k_fit, r0_fit, Omega0_fit, s0_fit, alpha_fit, beta_fit = result.x
print("Effective amplitude Omega0*k^2*r0/s0:", Omega0_fit * k_fit**2 * r0_fit / s0_fit)
print("Effective exponent alpha+beta:", alpha_fit + beta_fit)

# Plot best fit
mu_fit = model_mu(z, result.x)
# The data are not stored in redshift order. Sort before drawing the model
# line so it runs left-to-right instead of zig-zagging between points.
order = np.argsort(z)

plt.errorbar(z, mu_obs, yerr=dmb, fmt='.', alpha=0.5, label='Pan-STARRS1 SNe')
plt.plot(z[order], mu_fit[order], 'r-', label='Best-fit Emergent Gravity Model')
plt.xlabel('Redshift (z)')
plt.ylabel('Distance Modulus (μ)')
plt.legend()
plt.grid(True)
plt.title('Supernova Data and Emergent Gravity Model Fit')
plt.show()
