#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
hdgl_p2p_node.py
P2P HDGL node with 131072 channels, batch Base4096 export,
OpenCL HMAC acceleration, TCP peer streaming, lazy-loading, and OpenGL folding.
"""

import sys, math, struct, json, hmac, hashlib, socket, threading, time, random
import numpy as np
from base4096 import encode
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import compileShader, compileProgram
import pyopencl as cl

# -------------------------------
# CONFIG
# -------------------------------
# Lattice dimensions handed to the fragment shader as uniforms. LATTICE_WIDTH
# is also the default number of samples generate_batch() produces per channel.
LATTICE_WIDTH = 1920
LATTICE_HEIGHT = 1080
# Total number of channels; channel numbers are reduced modulo this value.
CHANNELS = 131072
# Vertical offset step used by idle(): yOffset = (channel % 24) * CHUNK_HEIGHT.
CHUNK_HEIGHT = LATTICE_HEIGHT // 24
PHI = 1.6180339887
# 72 float32 weights, 1 / PHI**(7*(i+1)), uploaded to the shader's phiPowers uniform.
PHI_POWERS = np.array([1.0/pow(PHI,7*(i+1)) for i in range(72)], dtype=np.float32)
# Shader cut-off: a slot evaluates to 1.0 only when its value exceeds this.
THRESHOLD = math.sqrt(PHI)
# Size of the slot index space; generate_batch() offsets channel ch's indices
# by ch * MAX_SLOTS // CHANNELS.
MAX_SLOTS = 16_777_216
# NOTE(review): not referenced anywhere in the visible code.
RECURSION_DEPTH = 3
BATCH_SIZE = 256          # Channels per network batch
NODE_CHANNELS = 256       # Channels this node perpetuates
# NOTE(review): hard-coded shared secret. Only the hashlib fallback path in
# generate_batch() uses it; the OpenCL path does not read the key.
HMAC_KEY = b"ZCHG-UltraHDGL-Key"
# When True, generate_batch() uses the OpenCL kernel if a program was built.
USE_OPENCL_HMAC = True
STREAM_PORT = 9999        # Listening port
# NOTE(review): the default bootstrap entry is this node's own listening port.
PEERS = [("127.0.0.1", STREAM_PORT)]  # bootstrap peers

# -------------------------------
# OpenCL kernel (dummy HMAC)
# -------------------------------
# OpenCL C source compiled by init_opencl(). Each work-item XORs one input byte
# with 0xAA and writes it to the same position in the output. No key is
# involved, so this is a placeholder and not a real HMAC.
OPENCL_KERNEL = """
__kernel void dummy_hmac(__global uchar *in_data, __global uchar *out_hash){
    int gid = get_global_id(0);
    out_hash[gid] = in_data[gid] ^ 0xAA;
}
"""

def init_opencl():
    """Create an OpenCL context, command queue and built program.

    The first device of the first reported platform is used, and
    ``OPENCL_KERNEL`` is compiled for that context.

    Returns:
        tuple: ``(context, queue, program)``.

    Raises:
        IndexError: if the platform list or the first platform's device list
            is empty. Errors raised by pyopencl itself propagate unchanged.
    """
    device = cl.get_platforms()[0].get_devices()[0]
    context = cl.Context([device])
    command_queue = cl.CommandQueue(context)
    built = cl.Program(context, OPENCL_KERNEL).build()
    return context, command_queue, built

# -------------------------------
# Slot helpers
# -------------------------------
def flatten_indices_to_bytes(indices):
    """Serialize indices as consecutive 3-byte big-endian integers.

    Each index is converted with ``int()`` and reduced modulo ``0x10FFFF``.
    A reduced value inside ``0xD800..0xDFFF`` is then incremented by one.

    NOTE(review): the +1 bump only moves ``0xDFFF`` out of that range; every
    other value in it stays inside. Kept as is because changing it would
    change the bytes that peers exchange.

    Args:
        indices: iterable of values accepted by ``int()``.

    Returns:
        bytes: three bytes per input index, in input order.
    """
    def pack_one(raw):
        value = int(raw) % 0x10FFFF
        bump = 1 if 0xD800 <= value <= 0xDFFF else 0
        return (value + bump).to_bytes(3, 'big')

    return b''.join(pack_one(raw) for raw in indices)

# -------------------------------
# Batch generator
# -------------------------------
def generate_batch(start_ch, batch_size, num_samples=LATTICE_WIDTH, ctx=None, queue=None, program=None):
    """Build the text export for ``batch_size`` consecutive channels.

    For every channel, ``num_samples`` slot indices are generated (offset by
    ``ch * MAX_SLOTS // CHANNELS``), flattened to bytes, tagged, and both the
    data and the tag are Base4096-encoded into one record::

        <encoded data>\\n#HMAC:<encoded tag>\\n

    The tag is the first 32 bytes of the ``dummy_hmac`` kernel output when
    ``USE_OPENCL_HMAC`` is set and ``program`` is truthy; otherwise it is an
    HMAC-SHA256 over the data keyed with ``HMAC_KEY``.

    NOTE(review): the two paths yield different tags for the same data, and
    the kernel path does not use the key.

    Args:
        start_ch: first channel number of the batch.
        batch_size: number of channels to generate.
        num_samples: indices generated per channel.
        ctx, queue, program: OpenCL handles as returned by ``init_opencl()``.

    Returns:
        str: the concatenated records for all channels in the batch.
    """
    records = []
    for channel in range(start_ch, start_ch + batch_size):
        offset = channel * MAX_SLOTS // CHANNELS
        slot_indices = np.arange(num_samples, dtype=np.uint32) + offset
        payload = flatten_indices_to_bytes(slot_indices)

        if not (USE_OPENCL_HMAC and program):
            # CPU fallback: real keyed HMAC-SHA256.
            tag = hmac.new(HMAC_KEY, payload, hashlib.sha256).digest()
        else:
            # Device path: run the kernel over every payload byte, keep 32.
            flags = cl.mem_flags
            dev_in = cl.Buffer(ctx, flags.READ_ONLY | flags.COPY_HOST_PTR,
                               hostbuf=np.frombuffer(payload, np.uint8))
            dev_out = cl.Buffer(ctx, flags.WRITE_ONLY, len(payload))
            program.dummy_hmac(queue, (len(payload),), None, dev_in, dev_out)
            host_out = np.empty_like(np.frombuffer(payload, np.uint8))
            cl.enqueue_copy(queue, host_out, dev_out)
            queue.finish()
            tag = bytes(host_out[:32])

        encoded_payload = encode(payload)
        encoded_tag = encode(tag)
        records.append(f"{encoded_payload}\n#HMAC:{encoded_tag}\n")
    return ''.join(records)

# -------------------------------
# P2P server/client
# -------------------------------
# Batches held by this node. Filled at startup and by get_batch(); read by the
# server's per-connection threads in client_thread().
local_batches = {}  # channel_start -> batch string

def client_thread(conn, addr):
    """Serve one peer connection, then close it.

    Protocol: the peer sends a 4-byte unsigned channel number (native byte
    order, matching ``fetch_batch_from_peers``). If this node holds a batch
    for that start channel, the batch is sent back UTF-8 encoded. Otherwise
    nothing is sent. The connection is always closed afterwards.

    Args:
        conn: connected socket for this peer.
        addr: peer address as returned by ``accept()`` (unused).
    """
    try:
        # recv() may return fewer bytes than requested on a stream socket, so
        # keep reading until the 4-byte request is complete or the peer closes.
        request = b''
        while len(request) < 4:
            piece = conn.recv(4 - len(request))
            if not piece:
                return  # peer closed before sending a full request
            request += piece
        start_ch = struct.unpack("I", request)[0]
        batch = local_batches.get(start_ch)
        if batch is not None:
            conn.sendall(batch.encode('utf-8'))
    finally:
        conn.close()

def start_server():
    """Listen on ``STREAM_PORT`` and hand each connection to a daemon thread.

    Binds to all interfaces and loops on ``accept()`` forever, starting one
    ``client_thread`` per peer. This function only returns by raising; the
    listening socket is closed when that happens.

    Raises:
        OSError: if binding, listening or accepting fails.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Let a restarted node rebind while old connections are in TIME_WAIT.
        # Skipped on Windows, where this option also permits a second process
        # to bind the same port.
        if sys.platform != "win32":
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('0.0.0.0', STREAM_PORT))
        sock.listen()
        print(f"🚀 P2P server running on port {STREAM_PORT}")
        while True:
            conn, addr = sock.accept()
            threading.Thread(target=client_thread, args=(conn, addr), daemon=True).start()

def fetch_batch_from_peers(start_ch):
    """Ask each configured peer in turn for the batch starting at ``start_ch``.

    The request is the channel number packed as a 4-byte unsigned int (native
    byte order, matching ``client_thread``). The reply is everything the peer
    sends before closing the connection, decoded as UTF-8. A peer that fails
    or replies with nothing is skipped.

    NOTE(review): the ``#HMAC:`` tag inside a fetched batch is not verified here.

    Args:
        start_ch: start channel of the wanted batch.

    Returns:
        str | None: the first non-empty reply, or ``None`` if no peer
        delivered one or ``start_ch`` cannot be packed into the request.
    """
    try:
        request = struct.pack("I", start_ch)
    except struct.error:
        # Not representable as an unsigned 32-bit value: no peer can be asked.
        return None

    for host, port in PEERS:
        try:
            with socket.create_connection((host, port), timeout=2) as s:
                s.sendall(request)
                # Read until the peer closes; join once instead of growing bytes.
                data = b''.join(iter(lambda: s.recv(8192), b''))
            if data:
                text = data.decode('utf-8')
                print(f"🔗 Fetched batch {start_ch} from {host}:{port}")
                return text
        except (OSError, ValueError, OverflowError):
            # Unreachable peer, timeout, undecodable reply or malformed peer
            # entry: best-effort, try the next one. KeyboardInterrupt and
            # SystemExit are deliberately not caught.
            continue
    return None

def get_batch(start_ch, ctx=None, queue=None, program=None):
    """Return the batch for ``start_ch``, trying three sources in order.

    1. The local cache ``local_batches``.
    2. The configured peers (the result is returned but not cached).
    3. Local generation, only when ``start_ch % CHANNELS`` is below
       ``NODE_CHANNELS``; the generated batch is stored in the cache.

    Args:
        start_ch: start channel of the wanted batch.
        ctx, queue, program: OpenCL handles forwarded to ``generate_batch``.

    Returns:
        str | None: the batch text, or ``None`` if no source could provide it.
    """
    try:
        return local_batches[start_ch]
    except KeyError:
        pass

    remote = fetch_batch_from_peers(start_ch)
    if remote:
        return remote

    # Only generate channels this node is responsible for.
    if (start_ch % CHANNELS) >= NODE_CHANNELS:
        return None
    generated = generate_batch(start_ch, BATCH_SIZE, ctx=ctx, queue=queue, program=program)
    local_batches[start_ch] = generated
    return generated

# -------------------------------
# OpenGL folding
# -------------------------------
# Mutable render state shared by the GLUT callbacks (display, idle, keyboard).
omega_time = 0.0        # shader time; display() adds 0.01 per frame
shader = None           # linked GL program, assigned by init_gl()
yOffset = 0             # row offset sent to the shader; recomputed in idle()
current_channel = 0     # selected channel; changed by idle() and keyboard()
frame_count = 0         # number of idle() calls so far
auto_scroll = True      # when True, idle() advances the channel every 60th call

# Vertex shader: passes the 2D position straight through and derives texCoord
# by remapping the position from [-1, 1] to [0, 1].
VERTEX_SRC = """
#version 450 core
layout(location=0) in vec2 pos;
out vec2 texCoord;
void main(){ texCoord=(pos+1.0)*0.5; gl_Position=vec4(pos,0,1); }
"""

# Fragment shader. For each fragment it maps texCoord to a lattice cell (x, y),
# with yOffset added to y, and hashes the cell index into a start value.
# hdgl_slot() adds time- and position-dependent terms and compares the sum
# with `threshold`, giving a slot of 1.0 or 0.0. The output color is
# channelColors[channelHighlight % 24] scaled by that slot, so a fragment is
# either the highlight color or black.
# computeVectorColor() and hdgl_slot() both ignore their idx argument.
FRAGMENT_SRC = """
#version 450 core
in vec2 texCoord;
out vec4 fragColor;
uniform float omegaTime;
uniform float phiPowers[72];
uniform float threshold;
uniform int latticeWidth;
uniform int latticeHeight;
uniform int yOffset;
uniform int channelHighlight;
vec3 channelColors[24]=vec3[24](
vec3(1,0,0),vec3(0,1,0),vec3(0,0,1),
vec3(1,1,0),vec3(1,0,1),vec3(0,1,1),
vec3(0.5,0,0),vec3(0,0.5,0),vec3(0,0,0.5),
vec3(0.5,0.5,0),vec3(0.5,0,0.5),vec3(0,0.5,0.5),
vec3(0.25,0,0),vec3(0,0.25,0),vec3(0,0,0.25),
vec3(0.25,0.25,0),vec3(0.25,0,0.25),vec3(0,0.25,0.25),
vec3(0.75,0,0),vec3(0,0.75,0),vec3(0,0,0.75),
vec3(0.75,0.75,0),vec3(0.75,0,0.75),vec3(0,0.75,0.75)
);
float hash_float(int i,int seed){ uint ui=uint(i*374761393 + seed*668265263u); return float(ui & 0xFFFFFFFFu)/4294967295.0; }
vec3 computeVectorColor(int idx,float slot){ return channelColors[channelHighlight%24]*slot; }
float hdgl_slot(float val,float r_dim,float omega,int x,int y,int idx){
    float resonance=(x%4==0?0.1*sin(omegaTime*0.05+float(y)):0.0);
    float wave=(x%3==0?0.3:(x%3==1?0.0:-0.3));
    float omega_inst=phiPowers[y%72];
    float rec=r_dim*val*0.5+0.25*sin(omegaTime*r_dim+float(x));
    float new_val=val+omega_inst+resonance+wave+rec+omega*0.05;
    return new_val>threshold?1.0:0.0;
}
void main(){
    int x=int(texCoord.x*float(latticeWidth));
    int y=int(texCoord.y*float(latticeHeight))+yOffset;
    int idx=y*latticeWidth+x;
    float val=hash_float(idx,12345);
    float r_dim=0.3+0.01*float(y);
    float slot=hdgl_slot(val,r_dim,sin(omegaTime),x,y,idx);
    vec3 color=computeVectorColor(idx,slot);
    fragColor=vec4(color.rgb,1.0);
}
"""

def display():
    """GLUT display callback: draw one frame.

    Clears the color buffer, uploads the per-frame uniforms (``omegaTime``,
    ``yOffset``, ``channelHighlight``), advances ``omega_time`` by 0.01, draws
    a single oversized triangle that covers the viewport, and swaps buffers.
    """
    global omega_time

    def uniform(name):
        return glGetUniformLocation(shader, name)

    glClear(GL_COLOR_BUFFER_BIT)
    glUseProgram(shader)
    glUniform1f(uniform("omegaTime"), omega_time)
    glUniform1i(uniform("yOffset"), yOffset)
    glUniform1i(uniform("channelHighlight"), current_channel % 24)
    omega_time += 0.01

    # One triangle with corners outside clip space fills the whole screen.
    glBegin(GL_TRIANGLES)
    for vx, vy in ((-1, -1), (3, -1), (-1, 3)):
        glVertex2f(vx, vy)
    glEnd()
    glutSwapBuffers()

def idle():
    """GLUT idle callback: advance state and request a redraw.

    Counts the call, steps to the next channel on every 60th call while
    ``auto_scroll`` is on, recomputes ``yOffset`` from the current channel,
    and posts a redisplay.
    """
    global yOffset, current_channel, frame_count
    frame_count += 1
    due = frame_count % 60 == 0
    if auto_scroll and due:
        current_channel = (current_channel + 1) % CHANNELS
    band = current_channel % 24
    yOffset = band * CHUNK_HEIGHT
    glutPostRedisplay()

def keyboard(key, x, y):
    """GLUT keyboard callback.

    ``w`` selects the previous channel and ``s`` the next one (both wrap
    around ``CHANNELS`` and turn auto-scroll off); ``a`` toggles auto-scroll.
    Other keys are ignored.

    Args:
        key: pressed key as a bytes object.
        x, y: pointer position supplied by GLUT (unused).
    """
    global current_channel, auto_scroll
    if key == b'a':
        auto_scroll = not auto_scroll
        return
    if key == b'w':
        step = -1
    elif key == b's':
        step = 1
    else:
        return
    current_channel = (current_channel + step) % CHANNELS
    auto_scroll = False

def init_gl():
    """Compile and link the shaders, then upload the constant uniforms.

    Stores the linked program in the module-level ``shader``, makes it
    current, and sets ``phiPowers``, ``threshold``, ``latticeWidth`` and
    ``latticeHeight`` once.
    """
    global shader
    vertex_stage = compileShader(VERTEX_SRC, GL_VERTEX_SHADER)
    fragment_stage = compileShader(FRAGMENT_SRC, GL_FRAGMENT_SHADER)
    shader = compileProgram(vertex_stage, fragment_stage)
    glUseProgram(shader)

    def uniform(name):
        return glGetUniformLocation(shader, name)

    glUniform1fv(uniform("phiPowers"), len(PHI_POWERS), PHI_POWERS)
    glUniform1f(uniform("threshold"), THRESHOLD)
    for name, value in (("latticeWidth", LATTICE_WIDTH),
                        ("latticeHeight", LATTICE_HEIGHT)):
        glUniform1i(uniform(name), value)

# -------------------------------
# MAIN
# -------------------------------
if __name__ == "__main__":
    # Set up OpenCL once. On failure the handles stay None and
    # generate_batch() uses its hashlib path.
    ctx = queue = program = None
    if USE_OPENCL_HMAC:
        try:
            ctx, queue, program = init_opencl()
        except Exception as e:
            print(f"⚠️ OpenCL init failed: {e}")

    # Build every batch this node is responsible for before serving peers.
    for batch_start in range(0, NODE_CHANNELS, BATCH_SIZE):
        local_batches[batch_start] = generate_batch(
            batch_start, BATCH_SIZE, ctx=ctx, queue=queue, program=program
        )
    print(f"✅ Pre-generated {len(local_batches)} batches locally")

    # Serve peers from a daemon thread so it ends with the process.
    server = threading.Thread(target=start_server, daemon=True)
    server.start()
    print("🖥 P2P node running, press Ctrl+C to exit.")

    # Create the window, register callbacks and enter the GLUT main loop.
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE)
    glutInitWindowSize(1280, 720)
    glutCreateWindow(b"HDGL P2P Node")
    init_gl()
    for register, callback in ((glutDisplayFunc, display),
                               (glutIdleFunc, idle),
                               (glutKeyboardFunc, keyboard)):
        register(callback)
    glutMainLoop()
