import socket
import threading
import pickle
import struct
import time
import numpy as np

# -------------------------------
# P2P CONFIG
# -------------------------------
PEER_PORT: int = 5005  # TCP port on which slices are sent and received
DISCOVERY_PORT: int = 5006  # UDP port used for HDGL_DISCOVERY beacons
PEER_LIST: set = set()  # (ip, PEER_PORT) tuples, filled in at runtime by discovery
SLICE_BROADCAST_INTERVAL: float = 2.0  # seconds slept between slice broadcast rounds
DISCOVERY_INTERVAL: float = 5.0  # seconds slept between discovery beacons

# -------------------------------
# SLICE OBJECT
# -------------------------------
class LatticeSlice:
    """A contiguous run of lattice rows owned by this node, plus a version counter.

    Attributes:
        start_idx: Position of this slice's first row in the full lattice; the
            P2P merge code subtracts it from an incoming slice's start index to
            find the overlapping local rows.
        rows: The row data. The rest of this module indexes it by row and reads
            ``.shape``, so it is expected to be a NumPy array.
        version: Starts at 0. The broadcast loop increments it each round, and
            merges overwrite it with a newer peer version.
    """

    def __init__(self, start_idx, rows):
        self.start_idx, self.rows, self.version = start_idx, rows, 0

# -------------------------------
# PHI-WEIGHTED TRANSFORM
# -------------------------------
def phi_weighted_transform(slice_data, phi_powers=None):
    """Scale each row of ``slice_data`` by a cyclically repeated weight.

    Row ``y`` is multiplied by ``phi_powers[y % len(phi_powers)]``, so the
    weight sequence wraps around when there are more rows than weights.

    Args:
        slice_data: Array whose first axis indexes rows (any trailing shape).
        phi_powers: 1-D sequence of per-row weights. Defaults to the
            module-level ``PHI_POWERS`` global.

    Returns:
        A new array with the same shape and dtype as ``slice_data``. The result
        is cast back to the input dtype, matching the original per-row
        assignment into ``np.empty_like``; integer input is therefore truncated.

    Raises:
        ValueError: if the weight sequence is empty and ``slice_data`` has rows.
    """
    slice_data = np.asarray(slice_data)
    rows = slice_data.shape[0]
    if rows == 0:
        # Nothing to weight; the weight table is never consulted.
        return np.empty_like(slice_data)
    if phi_powers is None:
        # NOTE(review): PHI_POWERS is not defined in this file's visible source.
        # It must be provided as a module global elsewhere, or passed
        # explicitly — confirm.
        phi_powers = PHI_POWERS
    weights = np.asarray(phi_powers)
    if len(weights) == 0:
        raise ValueError("phi_powers must not be empty")
    # Cycle the weights over the rows, then reshape so they broadcast across
    # every trailing axis of each row.
    per_row = weights[np.arange(rows) % len(weights)]
    per_row = per_row.reshape((rows,) + (1,) * (slice_data.ndim - 1))
    return (slice_data * per_row).astype(slice_data.dtype, copy=False)

# -------------------------------
# PEER DISCOVERY (UDP)
# -------------------------------
def discovery_broadcast_loop():
    """Broadcast a discovery beacon on the local network, forever.

    Every ``DISCOVERY_INTERVAL`` seconds, sends the datagram
    ``b"HDGL_DISCOVERY"`` to Python's special ``'<broadcast>'`` address on
    ``DISCOVERY_PORT``. A failed send (``OSError``, e.g. the network is
    temporarily down) is logged and retried on the next tick. Previously such
    an error ended the thread and silently stopped discovery for good.
    """
    msg = b"HDGL_DISCOVERY"
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        while True:
            try:
                sock.sendto(msg, ('<broadcast>', DISCOVERY_PORT))
            except OSError as e:
                print(f"[P2P] Discovery broadcast failed: {e}")
            time.sleep(DISCOVERY_INTERVAL)

def discovery_listener():
    """Record the sender of every discovery beacon as a peer, forever.

    Binds a UDP socket to ``DISCOVERY_PORT`` on all interfaces. For each
    datagram equal to ``b"HDGL_DISCOVERY"``, adds ``(sender_ip, PEER_PORT)``
    to ``PEER_LIST``; any other datagram is ignored. Discovery is one-way:
    no acknowledgement is sent back.

    NOTE(review): this host's own beacons (see discovery_broadcast_loop) will
    likely arrive here too, so the local node may add itself to PEER_LIST.
    Confirm whether that is intended.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.bind(('', DISCOVERY_PORT))
        while True:
            data, addr = sock.recvfrom(1024)
            if data != b"HDGL_DISCOVERY":
                continue
            # Peers are assumed to serve slices on the same fixed TCP port.
            PEER_LIST.add((addr[0], PEER_PORT))

# -------------------------------
# P2P SEND / RECEIVE
# -------------------------------
def send_slice_to_peer(peer_addr, lattice_slice, timeout=10.0):
    """Send a phi-weighted copy of ``lattice_slice`` to one peer over TCP.

    Wire format: a 4-byte big-endian length prefix, followed by a pickled dict
    with keys ``"start_idx"``, ``"version"`` and ``"rows"``. The rows are the
    output of ``phi_weighted_transform``. Any failure is logged and swallowed,
    so this is safe to run as a fire-and-forget thread.

    Args:
        peer_addr: ``(host, port)`` of the receiving peer.
        lattice_slice: Object exposing ``start_idx``, ``version`` and ``rows``.
        timeout: Seconds allowed for connecting and for sending the message.
            Without a bound, an unreachable peer keeps a sender thread blocked
            while the broadcast loop keeps starting new ones. ``None`` restores
            fully blocking behaviour.
    """
    try:
        # Build the payload before connecting, so a transform or pickle failure
        # doesn't open a connection that then carries no data.
        slice_data = phi_weighted_transform(lattice_slice.rows)
        payload = pickle.dumps({
            "start_idx": lattice_slice.start_idx,
            "version": lattice_slice.version,
            "rows": slice_data,
        })
        # The context manager closes the socket even if connect/sendall raises.
        with socket.create_connection(peer_addr, timeout=timeout) as s:
            s.sendall(struct.pack(">I", len(payload)) + payload)
    except Exception as e:
        print(f"[P2P] Failed to send slice to {peer_addr}: {e}")

def _recv_all(sock, n, chunk_size=65536):
    """Read up to ``n`` bytes from ``sock``, stopping early only if the peer closes.

    The result is shorter than ``n`` exactly when the connection closed first.
    Each ``recv`` is capped at ``chunk_size`` so a bogus length header cannot
    force one huge buffer allocation.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(min(n - len(buf), chunk_size))
        if not chunk:
            break
        buf += chunk
    return bytes(buf)


def _merge_incoming_rows(lattice_slice_obj, start_idx, incoming_rows):
    """Add ``incoming_rows`` (first row at global index ``start_idx``) into the overlapping local rows."""
    local_rows = lattice_slice_obj.rows
    rel_start = start_idx - lattice_slice_obj.start_idx
    # Overlap of [rel_start, rel_start + len(incoming)) with [0, len(local)).
    lo = max(rel_start, 0)
    hi = min(rel_start + incoming_rows.shape[0], local_rows.shape[0])
    if lo < hi:
        local_rows[lo:hi] += incoming_rows[lo - rel_start:hi - rel_start]


def p2p_listener(lattice_slice_obj, client_timeout=10.0):
    """Accept slice messages on ``PEER_PORT`` and merge newer ones, forever.

    Each connection carries one message: a 4-byte big-endian length prefix
    followed by a pickled dict with ``"start_idx"``, ``"version"`` and
    ``"rows"``. If the incoming version is strictly greater than the local
    one, the incoming rows are added element-wise into the overlapping local
    rows. The local version is then set to the incoming one. Connections are
    handled one at a time; errors are logged per connection.

    Args:
        lattice_slice_obj: Local slice (``start_idx``, ``rows``, ``version``),
            updated in place.
        client_timeout: Per-connection socket timeout in seconds, so a peer
            that connects and then stalls cannot block the accept loop forever.
            ``None`` restores fully blocking reads.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind(("", PEER_PORT))
        server.listen(5)
        print(f"[P2P] Listening on TCP {PEER_PORT}...")
        while True:
            client, addr = server.accept()
            with client:
                try:
                    client.settimeout(client_timeout)
                    # NOTE(review): request_latest_slice() sends a bare
                    # b"REQUEST_SLICE" with no length prefix. This loop would
                    # read "REQU" as a length and then time out; it never
                    # replies to that request.
                    header = _recv_all(client, 4)
                    if not header:
                        continue  # peer connected and closed without sending
                    if len(header) < 4:
                        raise ConnectionError("connection closed inside the length prefix")
                    length = struct.unpack(">I", header)[0]
                    data = _recv_all(client, length)
                    if len(data) < length:
                        raise ConnectionError(
                            f"connection closed after {len(data)} of {length} payload bytes"
                        )
                    # SECURITY: pickle.loads on bytes from the network can execute
                    # arbitrary code. Only run this on a trusted network. A safe
                    # format (e.g. np.save with allow_pickle=False) would need a
                    # matching change on the sending side.
                    payload = pickle.loads(data)
                    incoming_version = payload["version"]
                    if incoming_version > lattice_slice_obj.version:
                        _merge_incoming_rows(
                            lattice_slice_obj, payload["start_idx"], payload["rows"]
                        )
                        lattice_slice_obj.version = incoming_version
                        print(f"[P2P] Merged slice from {addr}, version {incoming_version}")
                except Exception as e:
                    print(f"[P2P] Error from {addr}: {e}")

# -------------------------------
# BROADCAST LOOP
# -------------------------------
def p2p_broadcast_loop(lattice_slice_obj):
    """Push the local slice to every known peer on a fixed cadence, forever.

    Each round does the following:
    - Bumps ``lattice_slice_obj.version`` by one, whether or not the rows changed.
    - Takes a snapshot of ``PEER_LIST``.
    - Starts one daemon thread per peer running ``send_slice_to_peer``.

    It then sleeps ``SLICE_BROADCAST_INTERVAL`` seconds without waiting for the
    sender threads to finish.
    """
    while True:
        lattice_slice_obj.version = lattice_slice_obj.version + 1
        # Snapshot: discovery_listener may add peers from another thread.
        targets = list(PEER_LIST)
        for target in targets:
            sender = threading.Thread(
                target=send_slice_to_peer,
                args=(target, lattice_slice_obj),
                daemon=True,
            )
            sender.start()
        time.sleep(SLICE_BROADCAST_INTERVAL)

# -------------------------------
# CATCH-UP FOR NEW PEERS
# -------------------------------
def request_latest_slice(peer_addr, lattice_slice_obj, timeout=10.0):
    """Optional: request the latest wavefront slice from a peer.

    Sends the raw bytes ``b"REQUEST_SLICE"`` and expects a reply framed as a
    4-byte big-endian length prefix plus a pickled dict with ``"rows"``,
    ``"start_idx"`` and ``"version"``. On success those three fields overwrite
    ``lattice_slice_obj``. On any failure the error is logged and the object is
    left unchanged, because all fields are read before any is assigned.

    NOTE(review): p2p_listener in this file reads a length prefix first and has
    no branch for REQUEST_SLICE, so against that listener no reply arrives. The
    timeout makes this fail instead of hanging.

    Args:
        peer_addr: ``(host, port)`` of the peer to sync from.
        lattice_slice_obj: Object whose ``rows``, ``start_idx`` and ``version``
            are replaced on success.
        timeout: Seconds allowed for connecting and for each blocking socket
            operation. ``None`` blocks indefinitely.
    """
    def recv_exact(sock, n):
        # Unlike a bare recv loop, this stops on EOF instead of spinning forever.
        buf = bytearray()
        while len(buf) < n:
            chunk = sock.recv(min(n - len(buf), 65536))
            if not chunk:
                raise ConnectionError(f"peer closed after {len(buf)} of {n} bytes")
            buf += chunk
        return bytes(buf)

    try:
        with socket.create_connection(peer_addr, timeout=timeout) as s:
            s.sendall(b"REQUEST_SLICE")
            length = struct.unpack(">I", recv_exact(s, 4))[0]
            data = recv_exact(s, length)
        # SECURITY: pickle.loads on network data can execute arbitrary code;
        # only use this with trusted peers.
        payload = pickle.loads(data)
        rows = payload["rows"]
        start_idx = payload["start_idx"]
        version = payload["version"]
        lattice_slice_obj.rows = rows
        lattice_slice_obj.start_idx = start_idx
        lattice_slice_obj.version = version
        print(f"[P2P] Synced latest slice from {peer_addr}")
    except Exception as e:
        print(f"[P2P] Failed to sync from {peer_addr}: {e}")

# -------------------------------
# START P2P
# -------------------------------
def start_p2p(lattice_slice_obj):
    """Launch the four P2P worker loops as daemon threads and return immediately.

    Threads are started in this order:
    1. Discovery beacon sender.
    2. Discovery listener.
    3. TCP slice listener.
    4. Periodic slice broadcaster.

    The last two share ``lattice_slice_obj``.
    """
    workers = (
        (discovery_broadcast_loop, ()),
        (discovery_listener, ()),
        (p2p_listener, (lattice_slice_obj,)),
        (p2p_broadcast_loop, (lattice_slice_obj,)),
    )
    for loop_fn, loop_args in workers:
        threading.Thread(target=loop_fn, args=loop_args, daemon=True).start()
