#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
hdgl_p2p_node_secure_final.py
Secure P2P HDGL node with 131072 channels, batch Base4096 export,
OpenCL HMAC acceleration, ephemeral per-channel salts, lazy-loading,
folded Base4096 bytes, and robust OpenGL folding.
"""

import hashlib
import hmac
import math
import os
import socket
import struct
import sys
import threading
import time

import numpy as np
from base4096 import encode
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import compileShader, compileProgram
import pyopencl as cl

# -------------------------------
# CONFIG
# -------------------------------
LATTICE_WIDTH = 1920   # lattice columns; also generate_batch()'s default samples per channel
LATTICE_HEIGHT = 1080  # lattice rows
CHANNELS = 131072      # total channel count; channel numbers wrap modulo this
# Vertical scroll step: idle() sets yOffset = (channel % 24) * CHUNK_HEIGHT.
CHUNK_HEIGHT = LATTICE_HEIGHT // 24
PHI = 1.6180339887
# PHI**-7, PHI**-14, ..., PHI**-504 as float32.
# NOTE(review): nothing in this file uploads these to the shader's phiPowers uniform.
PHI_POWERS = np.array([1.0/pow(PHI,7*(i+1)) for i in range(72)], dtype=np.float32)
# NOTE(review): sqrt(PHI) ~= 1.272 exceeds every value fract() can return, so
# step(THRESHOLD, fract(x)) in the fragment shader would always be 0; display()
# does not currently upload it. Confirm the intended value before wiring it up.
THRESHOLD = math.sqrt(PHI)
# Size of the global slot space; channel ch's indices start at
# ch * MAX_SLOTS // CHANNELS (128 slots apart with the values above).
MAX_SLOTS = 16_777_216
BATCH_SIZE = 256       # consecutive channels per exported batch
NODE_CHANNELS = 256    # channels (mod CHANNELS) this node generates itself
# Root secret for per-channel salt derivation. The built-in value is a
# development default that is public in this source file; set HDGL_MASTER_KEY
# in the environment to override it. Nodes exchanging batches must share the
# same key for their #HMAC tags to agree.
MASTER_KEY = os.environ.get("HDGL_MASTER_KEY", "ZCHG-UltraHDGL-SuperKey").encode("utf-8")
USE_OPENCL_HMAC = True  # OpenCL toggle; checked at start-up in __main__
STREAM_PORT = 9999      # TCP port the batch server listens on
PEERS = [("127.0.0.1", STREAM_PORT)]  # (host, port) pairs queried in order for missing batches

# -------------------------------
# OpenCL kernel (HMAC placeholder -- not a real MAC)
# -------------------------------
# Despite its name, "secure_hmac" only XORs each input byte with a salt byte:
# out_hash[i] = in_data[i] ^ salt[i % 32]. It provides no authentication, and
# XORing its output with the input recovers the salt. The salt buffer must hold
# at least 32 bytes. There is no bounds check, so the launch's global size must
# not exceed the length of in_data / out_hash.
OPENCL_KERNEL = """
__kernel void secure_hmac(__global uchar *in_data, __global uchar *salt, __global uchar *out_hash){
    int gid = get_global_id(0);
    out_hash[gid] = in_data[gid] ^ salt[gid % 32];
}
"""

# Kernel object created once by init_opencl() and stored here so it can be
# reused instead of being looked up from the program again.
cl_kernel_instance = None  # reuse to avoid repeated retrieval

def init_opencl():
    """Set up OpenCL on the first device of the first platform.

    Builds OPENCL_KERNEL, stores its "secure_hmac" kernel in the module-level
    ``cl_kernel_instance`` and returns ``(context, queue, kernel)``.
    Any pyopencl error propagates; the ``__main__`` block catches it and
    continues without OpenCL.
    """
    global cl_kernel_instance
    first_device = cl.get_platforms()[0].get_devices()[0]
    context = cl.Context([first_device])
    command_queue = cl.CommandQueue(context)
    built_program = cl.Program(context, OPENCL_KERNEL).build()
    cl_kernel_instance = cl.Kernel(built_program, "secure_hmac")
    return context, command_queue, cl_kernel_instance

# -------------------------------
# Slot helpers
# -------------------------------
def flatten_indices_to_bytes(indices):
    """Pack every index as a 3-byte big-endian value in [0, 0x10FFFE].

    Each index is converted with ``int()`` and reduced modulo 0x10FFFF; a value
    that lands in 0xD800..0xDFFF is bumped by one before packing.

    NOTE(review): the bump only moves 0xDFFF out of the surrogate range
    (0xD800..0xDFFE stay inside it). Within this file the bytes are only
    folded and Base4096-encoded, never decoded as text, so this is harmless —
    but it does not actually skip surrogates. Changing it would alter the
    exported bytes, so coordinate with peers first.
    """
    def _packed(value):
        code = int(value) % 0x10FFFF
        if 0xD800 <= code <= 0xDFFF:
            code += 1
        return code.to_bytes(3, 'big')

    return b''.join(_packed(value) for value in indices)

def fold_bytes(data):
    """Apply a fixed, position-dependent byte transform to *data*.

    Byte ``b`` at position ``i`` becomes ``((b ^ 0x5A) + (i * 13) % 256) & 0xFF``.
    The output has the same length as the input.
    """
    return bytes(
        ((value ^ 0x5A) + (pos * 13) % 256) & 0xFF
        for pos, value in enumerate(data)
    )

def derive_salt(channel_idx, batch_size):
    """Derive the 32-byte salt for one channel via two HMAC-SHA256 steps.

    1. intermediate = HMAC-SHA256(MASTER_KEY, b"hdgl_salt_prk")
    2. salt = HMAC-SHA256(intermediate, "hdgl:channel:<idx>:batch:<size>")

    Both ``channel_idx`` and ``batch_size`` are bound into the result, so the
    same channel exported with a different batch size gets a different salt.
    """
    intermediate_key = hmac.digest(MASTER_KEY, b"hdgl_salt_prk", hashlib.sha256)
    context = f"hdgl:channel:{channel_idx}:batch:{batch_size}".encode('utf-8')
    return hmac.digest(intermediate_key, context, hashlib.sha256)[:32]

# -------------------------------
# Batch generator
# -------------------------------
def _channel_payload(ch, num_samples):
    """Return the folded byte payload for channel *ch* covering *num_samples* slots."""
    indices = np.arange(num_samples, dtype=np.uint32) + ch * MAX_SLOTS // CHANNELS
    return fold_bytes(flatten_indices_to_bytes(indices))


def generate_batch(start_ch, batch_size, num_samples=LATTICE_WIDTH, ctx=None, queue=None, kernel=None):
    """Build the Base4096 text export for channels [start_ch, start_ch + batch_size).

    For each channel the slot indices are flattened to 3-byte values and folded.
    The result is then authenticated with HMAC-SHA256, keyed by
    ``derive_salt(ch, batch_size)``. Each channel contributes
    ``"<base4096 payload>\\n#HMAC:<base4096 tag>\\n"`` to the output.

    Args:
        start_ch: first channel of the batch.
        batch_size: number of consecutive channels; also bound into each salt.
        num_samples: slots per channel (default LATTICE_WIDTH).
        ctx, queue, kernel: OpenCL handles from init_opencl(). Accepted for
            backward compatibility but intentionally unused. The OpenCL
            "secure_hmac" kernel only XORs the payload with the salt, which
            is not a MAC. Because the payload is emitted in the same record,
            anyone could recover the salt (the HMAC key) by XORing it back.
            Its tags also disagreed with nodes that ran the CPU path.

    Returns:
        The concatenated batch text (empty string when batch_size <= 0).
    """
    records = []
    for ch in range(start_ch, start_ch + batch_size):
        payload = _channel_payload(ch, num_samples)
        tag = hmac.new(derive_salt(ch, batch_size), payload, hashlib.sha256).digest()
        records.append(f"{encode(payload)}\n#HMAC:{encode(tag)}\n")
    return ''.join(records)

# -------------------------------
# P2P server/client
# -------------------------------
local_batches = {}

# Seconds a connected peer may stay silent before the server gives up on it;
# without a timeout an idle client pins its handler thread indefinitely.
_CLIENT_TIMEOUT = 5.0


def _recv_exact(conn, size):
    """Read exactly *size* bytes from *conn*, or fewer if the peer closes early.

    A single ``recv`` on a stream socket may return fewer bytes than requested,
    so a fixed-size header must be read in a loop.
    """
    buf = bytearray()
    while len(buf) < size:
        chunk = conn.recv(size - len(buf))
        if not chunk:
            break
        buf.extend(chunk)
    return bytes(buf)


def client_thread(conn, addr):
    """Serve one peer request, then close the connection.

    The peer sends the batch's start channel packed with ``struct`` format
    ``"I"`` (native byte order and size; fetch_batch_from_peers packs it the
    same way). If that batch is in ``local_batches`` its UTF-8 text is sent
    back; otherwise nothing is sent. Socket errors and timeouts are reported
    and the connection is dropped.

    Args:
        conn: accepted client socket; always closed on return.
        addr: peer address, used only in diagnostics.
    """
    try:
        conn.settimeout(_CLIENT_TIMEOUT)
        start_bytes = _recv_exact(conn, 4)
        if len(start_bytes) < 4:
            return
        start_ch = struct.unpack("I", start_bytes)[0]
        batch = local_batches.get(start_ch)
        if batch is not None:
            conn.sendall(batch.encode('utf-8'))
    except OSError as exc:  # includes socket timeouts and connection resets
        print(f"⚠️ Peer {addr} request failed: {exc}")
    finally:
        conn.close()

def start_server():
    """Accept peer connections on STREAM_PORT forever, one daemon thread each.

    Binds to all interfaces (0.0.0.0) and hands every accepted connection to
    client_thread. Never returns normally; run it in a daemon thread.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Allow an immediate restart while old connections sit in TIME_WAIT.
        # Not set on Windows, where SO_REUSEADDR would let another socket bind
        # the same port.
        if sys.platform != "win32" and hasattr(socket, "SO_REUSEADDR"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('0.0.0.0', STREAM_PORT))
        sock.listen()
        print(f"🚀 Secure P2P server running on port {STREAM_PORT}")
        while True:
            conn, addr = sock.accept()
            threading.Thread(target=client_thread, args=(conn, addr), daemon=True).start()

def fetch_batch_from_peers(start_ch, *, max_bytes=64 * 1024 * 1024):
    """Ask each peer in PEERS, in order, for the batch starting at *start_ch*.

    Sends the start channel as ``struct`` ``"I"`` and reads the reply until
    the peer closes the connection (2 s connect/read timeout).

    Args:
        start_ch: start channel of the wanted batch.
        max_bytes: largest reply accepted from a single peer. Peer data is
            untrusted, so an oversized reply is discarded and the next peer
            is tried.

    Returns:
        The decoded batch text from the first peer that returns a non-empty,
        valid-UTF-8 reply within the size limit; otherwise None.
        NOTE(review): the reply's #HMAC lines are not verified here.
    """
    for host, port in PEERS:
        try:
            request = struct.pack("I", start_ch)
            with socket.create_connection((host, port), timeout=2) as s:
                s.sendall(request)
                chunks = []
                received = 0
                while True:
                    chunk = s.recv(8192)
                    if not chunk:
                        break
                    received += len(chunk)
                    if received > max_bytes:
                        raise ValueError(f"reply from {host}:{port} exceeds {max_bytes} bytes")
                    chunks.append(chunk)
            data = b''.join(chunks)
            if data:
                batch = data.decode('utf-8')
                print(f"🔗 Fetched batch {start_ch} from {host}:{port}")
                return batch
        except (OSError, ValueError, struct.error):
            # Best-effort: unreachable, timed-out, oversized or non-UTF-8
            # (UnicodeDecodeError is a ValueError) peers are skipped.
            continue
    return None

def get_batch(start_ch, ctx=None, queue=None, kernel=None):
    """Return the batch text for *start_ch*, looking in three places in turn.

    1. The local cache ``local_batches``.
    2. Any peer (via fetch_batch_from_peers); a peer result is returned but
       not cached.
    3. Local generation, only if ``start_ch % CHANNELS`` falls within this
       node's first NODE_CHANNELS channels; the result is cached.

    Returns None when none of these produce a batch.
    """
    try:
        return local_batches[start_ch]
    except KeyError:
        pass

    remote = fetch_batch_from_peers(start_ch)
    if remote:
        return remote

    if (start_ch % CHANNELS) >= NODE_CHANNELS:
        return None

    fresh = generate_batch(start_ch, BATCH_SIZE, ctx=ctx, queue=queue, kernel=kernel)
    local_batches[start_ch] = fresh
    return fresh

# -------------------------------
# OpenGL folding
# -------------------------------
# Render/view state shared by the GLUT callbacks below.
omega_time = 0.0       # animation clock uploaded as omegaTime; display() adds 0.01 per shader-drawn frame
shader = None          # linked GL program from init_gl(), or None if compilation failed
yOffset = 0            # uploaded as yOffset; idle() sets it to (current_channel % 24) * CHUNK_HEIGHT
current_channel = 0    # channel drawn and labelled; wraps modulo CHANNELS
frame_count = 0        # idle() tick counter; auto-scroll advances every 60 ticks
auto_scroll = True     # toggled with 'a'; cleared by 'w'/'s'

# Pass-through vertex shader: forwards the 2D clip-space position and maps it
# from [-1, 1] to a [0, 1] texture coordinate.
# NOTE(review): display() feeds `pos` with glVertex2f (immediate mode), which
# relies on a compatibility-profile context aliasing glVertex to attribute 0.
VERTEX_SRC = """
#version 330 core
layout(location=0) in vec2 pos;
out vec2 texCoord;
void main(){ texCoord=(pos+1.0)*0.5; gl_Position=vec4(pos,0,1); }
"""

# Fragment shader: colours each pixel with channelColors[channelHighlight % 24],
# gated by step(threshold, fract(sin(...))) over a per-pixel lattice index.
# NOTE(review): display() never sets the `threshold` or `phiPowers` uniforms.
# Uniforms default to 0, so step(0.0, fract(x)) is always 1 and every pixel is
# drawn in the solid channel colour; phiPowers is declared but unused.
FRAGMENT_SRC = """
#version 330 core
in vec2 texCoord;
out vec4 fragColor;

uniform float omegaTime;
uniform float phiPowers[72];
uniform float threshold;
uniform int latticeWidth;
uniform int latticeHeight;
uniform int yOffset;
uniform int channelHighlight;

// Base 24 colors
vec3 channelColors[24] = vec3[24](
    vec3(1,0,0), vec3(0,1,0), vec3(0,0,1),
    vec3(1,1,0), vec3(1,0,1), vec3(0,1,1),
    vec3(0.5,0,0), vec3(0,0.5,0), vec3(0,0,0.5),
    vec3(0.5,0.5,0), vec3(0.5,0,0.5), vec3(0,0.5,0.5),
    vec3(0.25,0,0), vec3(0,0.25,0), vec3(0,0,0.25),
    vec3(0.25,0.25,0), vec3(0.25,0,0.25), vec3(0,0.25,0.25),
    vec3(0.75,0,0), vec3(0,0.75,0), vec3(0,0,0.75),
    vec3(0.75,0.75,0), vec3(0.75,0,0.75), vec3(0,0.75,0.75)
);

void main() {
    int x = int(texCoord.x * float(latticeWidth));
    int y = int(texCoord.y * float(latticeHeight)) + yOffset;
    int idx = y * latticeWidth + x;

    // Distinct pattern per channel
    float chPattern = sin(float(channelHighlight)*0.1 + float(idx)*0.01 + omegaTime*2.0);
    float slot = step(threshold, fract(chPattern));

    int ch = channelHighlight % 24;
    vec3 color = channelColors[ch] * slot;

    fragColor = vec4(color,1.0);
}
"""

def init_gl():
    """Compile and link the lattice shader program into the global ``shader``.

    On any failure the error is printed and ``shader`` is set to None. With no
    shader, display() skips the lattice pass and draws only the text overlay.
    """
    global shader
    try:
        program = compileProgram(
            compileShader(VERTEX_SRC, GL_VERTEX_SHADER),
            compileShader(FRAGMENT_SRC, GL_FRAGMENT_SHADER),
        )
    except Exception as err:
        print(f"⚠️ Shader compilation failed: {err}")
        shader = None
    else:
        shader = program

def display():
    """GLUT display callback: draw the lattice pass, then the channel label.

    If the shader program is available, this:
    - uploads the per-frame uniforms;
    - advances ``omega_time`` by 0.01;
    - draws one oversized triangle that covers the viewport.

    It then writes "Channel: N" with GLUT bitmap text at window position
    (10, 10) and swaps buffers.
    """
    global omega_time  # the only module state this callback writes
    glClear(GL_COLOR_BUFFER_BIT)
    if shader:
        glUseProgram(shader)
        glUniform1f(glGetUniformLocation(shader, "omegaTime"), omega_time)
        glUniform1i(glGetUniformLocation(shader, "yOffset"), yOffset)
        glUniform1i(glGetUniformLocation(shader, "latticeWidth"), LATTICE_WIDTH)
        glUniform1i(glGetUniformLocation(shader, "latticeHeight"), LATTICE_HEIGHT)
        glUniform1i(glGetUniformLocation(shader, "channelHighlight"), current_channel)
        omega_time += 0.01

        # Fullscreen triangle: (-1,-1), (3,-1), (-1,3) covers the whole
        # [-1, 1] clip square with a single primitive.
        glBegin(GL_TRIANGLES)
        glVertex2f(-1, -1)
        glVertex2f(3, -1)
        glVertex2f(-1, 3)
        glEnd()

        # Unbind the program before the text overlay. Otherwise the bitmap
        # fragments run through the lattice fragment shader, which ignores
        # glColor, and the label is not drawn in white.
        glUseProgram(0)

    # Overlay current channel number
    glColor3f(1.0, 1.0, 1.0)
    glWindowPos2i(10, 10)
    for c in f"Channel: {current_channel}":
        glutBitmapCharacter(GLUT_BITMAP_HELVETICA_18, ord(c))

    glutSwapBuffers()

def idle():
    """GLUT idle callback: tick the frame counter, auto-scroll, request a redraw.

    While ``auto_scroll`` is on, ``current_channel`` advances by one
    (mod CHANNELS) every 60 idle ticks. ``yOffset`` is then recomputed from
    the channel's position in the 24-colour cycle.
    """
    global yOffset, current_channel, frame_count
    frame_count += 1
    advance = auto_scroll and frame_count % 60 == 0
    if advance:
        current_channel = (current_channel + 1) % CHANNELS
    yOffset = CHUNK_HEIGHT * (current_channel % 24)
    glutPostRedisplay()

def keyboard(key, x, y):
    """GLUT keyboard callback.

    'w' / 's' step ``current_channel`` down / up (mod CHANNELS) and turn
    auto-scroll off; 'a' toggles auto-scroll. Other keys are ignored.
    ``x`` and ``y`` (mouse position supplied by GLUT) are unused.
    """
    global current_channel, auto_scroll
    step = {b'w': -1, b's': 1}.get(key)
    if step is not None:
        current_channel = (current_channel + step) % CHANNELS
        auto_scroll = False
    elif key == b'a':
        auto_scroll = not auto_scroll

# -------------------------------
# MAIN
# -------------------------------
if __name__ == "__main__":
    # Optional OpenCL set-up; any failure is reported and the node continues
    # with ctx/queue/kernel left as None.
    ctx = queue = kernel = None
    if USE_OPENCL_HMAC:
        try:
            ctx, queue, kernel = init_opencl()
        except Exception as e:
            print(f"⚠️ OpenCL init failed: {e}")

    # Pre-generate the batches this node owns so peers can fetch them.
    for start_ch in range(0, NODE_CHANNELS, BATCH_SIZE):
        local_batches[start_ch] = generate_batch(
            start_ch, BATCH_SIZE, ctx=ctx, queue=queue, kernel=kernel,
        )
    print(f"✅ Pre-generated {len(local_batches)} secure batches locally")

    # Serve batches to peers from a background daemon thread.
    threading.Thread(target=start_server, daemon=True).start()
    print("🖥 Secure P2P node running, press Ctrl+C to exit.")

    # GLUT window, shader and callbacks, then hand control to GLUT's event loop.
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE)
    glutInitWindowSize(1280, 720)
    glutCreateWindow(b"HDGL P2P Node Secure")
    init_gl()
    glutDisplayFunc(display)
    glutIdleFunc(idle)
    glutKeyboardFunc(keyboard)
    glutMainLoop()
