#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
hdgl_p2p_node_secure_final_v4.py
Secure P2P HDGL node with 131072 channels, independent per-channel movies, continuous HSV
"""

import sys, math, struct, hmac, hashlib, socket, threading, time
import numpy as np
from base4096 import encode
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import compileShader, compileProgram
import pyopencl as cl

# -------------------------------
# CONFIG
# -------------------------------
LATTICE_WIDTH = 1920  # lattice columns sent to the shader; also generate_batch's default sample count
LATTICE_HEIGHT = 1080  # lattice rows sent to the shader; yOffset wraps at this value
CHANNELS = 131072  # total number of channels (2**17)
CHUNK_HEIGHT = max(1, LATTICE_HEIGHT // 512)  # yOffset rows advanced per channel index (2 with these defaults)
PHI = 1.6180339887  # golden ratio
# 72 weights 1 / PHI**(7*(i+1)), uploaded to the fragment shader's phiPowers[72] uniform.
# NOTE(review): PHI**7 is ~29, so in float32 the entries for roughly i >= 30
# underflow to exactly 0.0 — confirm this is intended before relying on them.
PHI_POWERS = np.array([1.0/pow(PHI,7*(i+1)) for i in range(72)], dtype=np.float32)
THRESHOLD = 0.5  # centre of the shader's smoothstep cut-off; only effective if uploaded as `threshold`
MAX_SLOTS = 16_777_216  # slot-index space; channel ch starts at ch * MAX_SLOTS // CHANNELS (128 per channel here)
BATCH_SIZE = 256  # channels covered by one generated batch
NODE_CHANNELS = 256  # channels (mod CHANNELS) this node generates itself instead of fetching
# Root secret from which derive_salt() derives every per-channel HMAC key.
# NOTE(review): hard-coded in source, so every copy of this file shares it;
# consider loading it from the environment or a key file.
MASTER_KEY = b"ZCHG-UltraHDGL-SuperKey"
USE_OPENCL_HMAC = True  # try to initialise OpenCL at start-up
STREAM_PORT = 9999  # TCP port the batch server listens on
# (host, port) pairs queried by fetch_batch_from_peers(); the default entry is
# this node's own server on localhost.
PEERS = [("127.0.0.1", STREAM_PORT)]

# -------------------------------
# Per-channel procedural state
# -------------------------------
# Animation clock per channel; idle() advances only the entry of the channel on screen.
channel_times = np.zeros(CHANNELS, dtype=np.float32)
# Random per-channel value in [0, 1), passed to the shader as channelSeed.
# Unseeded, so the patterns differ between runs.
channel_seeds = np.random.rand(CHANNELS).astype(np.float32)

# -------------------------------
# OpenCL kernel
# -------------------------------
# OpenCL C source for the kernel named "secure_hmac". Despite the name it does
# not compute an HMAC: each work item writes in_data[gid] XOR salt[gid % 32],
# i.e. the input XORed with a repeating 32-byte salt.
# NOTE(review): a salt XOR is not a message authentication code — anyone who
# sees both the input bytes and the output can recover the salt.
OPENCL_KERNEL = """
__kernel void secure_hmac(__global uchar *in_data, __global uchar *salt, __global uchar *out_hash){
    int gid = get_global_id(0);
    out_hash[gid] = in_data[gid] ^ salt[gid % 32];
}
"""
# Kernel object stored by init_opencl(); None until OpenCL has been initialised.
cl_kernel_instance = None
def init_opencl():
    """Build the ``secure_hmac`` kernel on the first device of the first OpenCL platform.

    Side effect: the kernel is also stored in the module-level
    ``cl_kernel_instance``.

    Returns:
        ``(context, command_queue, kernel)``.
    """
    global cl_kernel_instance
    first_platform = cl.get_platforms()[0]
    first_device = first_platform.get_devices()[0]
    context = cl.Context([first_device])
    command_queue = cl.CommandQueue(context)
    built_program = cl.Program(context, OPENCL_KERNEL).build()
    cl_kernel_instance = cl.Kernel(built_program, "secure_hmac")
    return context, command_queue, cl_kernel_instance

# -------------------------------
# Slot helpers
# -------------------------------
def flatten_indices_to_bytes(indices):
    """Pack each index as a 3-byte big-endian code-point-like integer.

    Each index is reduced modulo 0x10FFFF, giving 0..0x10FFFE (U+10FFFF itself
    is never produced). Any result inside the UTF-16 surrogate block
    U+D800..U+DFFF is shifted up by 0x800 into U+E000..U+E7FF, so no emitted
    value is a surrogate.

    Args:
        indices: iterable of integers (anything accepted by ``int()``).

    Returns:
        bytes, three per index, in input order.
    """
    out = bytearray()
    for idx in indices:
        cp = int(idx) % 0x10FFFF
        if 0xD800 <= cp <= 0xDFFF:
            # Jump past the whole 2048-value surrogate block; a +1 nudge would
            # leave almost all of it still inside the block.
            cp += 0x800
        out += cp.to_bytes(3, "big")
    return bytes(out)

def fold_bytes(data):
    """Whiten *data* byte by byte: XOR with 0x5A, then add 13 * position (mod 256).

    Args:
        data: sized bytes-like object (each element a small int).

    Returns:
        bytes with the same length as *data*.
    """
    # Adding 13*offset and masking with 0xFF is the same as adding
    # (13*offset) % 256 first: both are congruent modulo 256.
    return bytes(
        ((value ^ 0x5A) + 13 * offset) & 0xFF
        for offset, value in zip(range(len(data)), data)
    )

def derive_salt(channel_idx, batch_size):
    """Return the 32-byte HMAC key for one channel of a batch.

    Derivation:
        1. prk  = HMAC-SHA256(key=MASTER_KEY, msg=b"hdgl_salt_prk")
        2. salt = HMAC-SHA256(key=prk, msg="hdgl:channel:<idx>:batch:<size>")

    Args:
        channel_idx: channel number, formatted into the label.
        batch_size: batch size, formatted into the label.
    """
    label = f"hdgl:channel:{channel_idx}:batch:{batch_size}".encode('utf-8')
    pseudo_random_key = hmac.digest(MASTER_KEY, b"hdgl_salt_prk", hashlib.sha256)
    # A SHA-256 digest is already 32 bytes; the slice just pins the contract.
    return hmac.digest(pseudo_random_key, label, hashlib.sha256)[:32]

# -------------------------------
# Batch generator
# -------------------------------
def generate_batch(start_ch, batch_size, num_samples=LATTICE_WIDTH, ctx=None, queue=None, kernel=None):
    """Build the text payload for channels ``start_ch`` .. ``start_ch + batch_size - 1``.

    For each channel ``ch``:
      * take ``num_samples`` consecutive slot indices starting at
        ``ch * MAX_SLOTS // CHANNELS`` (uint32),
      * pack them with flatten_indices_to_bytes() and whiten with fold_bytes(),
      * authenticate the folded bytes with HMAC-SHA256 keyed by
        ``derive_salt(ch, batch_size)``,
      * emit one line of base4096-encoded data followed by a ``#HMAC:`` line
        holding the base4096-encoded digest.
    Channel records are concatenated in channel order.

    ctx, queue and kernel are accepted for backward compatibility but are not
    used. The OpenCL "secure_hmac" kernel only XORs the data with the salt.
    That output is not a MAC: the salt is recoverable as output XOR input, and
    with it any payload can be "signed". It also never matched the
    HMAC-SHA256 value the CPU path produced for the same channel.

    Returns:
        The concatenated batch text.
    """
    records = []
    for ch in range(start_ch, start_ch + batch_size):
        first_slot = ch * MAX_SLOTS // CHANNELS
        indices = np.arange(num_samples, dtype=np.uint32) + first_slot
        folded_bytes = fold_bytes(flatten_indices_to_bytes(indices))
        salt = derive_salt(ch, batch_size)
        hmac_digest = hmac.new(salt, folded_bytes, hashlib.sha256).digest()
        encoded_data = encode(folded_bytes)
        encoded_hmac = encode(hmac_digest)
        records.append(f"{encoded_data}\n#HMAC:{encoded_hmac}\n")
    return ''.join(records)

# -------------------------------
# P2P server/client
# -------------------------------
# Batches this node can serve to peers, keyed by starting channel; values are
# the text payloads produced by generate_batch().
local_batches = {}
def client_thread(conn, addr, timeout=5.0):
    """Answer one peer request on *conn*, then close the connection.

    Protocol: the peer sends the starting channel as a 4-byte unsigned int in
    native byte order (struct format "I"). If that batch is in local_batches,
    its UTF-8 text is sent back. Otherwise the connection is closed with no
    reply.

    Args:
        conn: accepted socket (uses settimeout, recv, sendall, close).
        addr: peer address; only used in the error message.
        timeout: seconds allowed per socket operation, so a silent or stalled
            peer cannot hold this thread forever.
    """
    try:
        conn.settimeout(timeout)
        request = b""
        # recv() may return fewer bytes than requested; keep reading until the
        # 4-byte header is complete or the peer hangs up.
        while len(request) < 4:
            chunk = conn.recv(4 - len(request))
            if not chunk:
                return
            request += chunk
        start_ch = struct.unpack("I", request)[0]
        batch = local_batches.get(start_ch)
        if batch is not None:
            conn.sendall(batch.encode('utf-8'))
    except OSError as exc:
        # Timeouts and resets only affect this peer; report and drop it.
        print(f"⚠️ Peer {addr} request failed: {exc}")
    finally:
        conn.close()

def start_server():
    """Accept peer connections on STREAM_PORT forever, one daemon thread each.

    Binds every interface (0.0.0.0), so any host that can reach the port may
    request batches. Never returns; intended to run in a background thread.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # client_thread closes connections first, leaving them in TIME_WAIT on
        # this port. On POSIX systems SO_REUSEADDR lets a restarted node rebind
        # immediately instead of failing with "Address already in use".
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('0.0.0.0', STREAM_PORT))
        sock.listen()
        print(f"🚀 Secure P2P server running on port {STREAM_PORT}")
        while True:
            try:
                conn, addr = sock.accept()
            except OSError as exc:
                # e.g. ECONNABORTED or EMFILE: keep serving, but back off briefly
                # so a persistent error does not spin the CPU.
                print(f"⚠️ accept() failed: {exc}")
                time.sleep(0.1)
                continue
            threading.Thread(target=client_thread, args=(conn, addr), daemon=True).start()

def fetch_batch_from_peers(start_ch, peers=None):
    """Ask each peer in turn for the batch starting at *start_ch*.

    Sends the channel as a native-order 4-byte unsigned int (struct "I") and
    reads until the peer closes the connection. The first non-empty reply is
    decoded as UTF-8 and returned. Unreachable peers, socket errors and replies
    that are not valid UTF-8 are skipped.

    Args:
        start_ch: starting channel of the requested batch.
        peers: iterable of ``(host, port)`` pairs; defaults to ``PEERS``.

    Returns:
        The batch text, or None if no peer returned data (or *start_ch* cannot
        be packed into the 4-byte request).

    NOTE(review): the reply is not verified against its ``#HMAC`` lines, so a
    peer can return arbitrary content.
    """
    if peers is None:
        peers = PEERS
    try:
        request = struct.pack("I", start_ch)
    except struct.error:
        return None
    for host, port in peers:
        try:
            with socket.create_connection((host, port), timeout=2) as s:
                s.sendall(request)
                chunks = []
                while True:
                    chunk = s.recv(8192)
                    if not chunk:
                        break
                    chunks.append(chunk)
            data = b"".join(chunks)
            if data:
                print(f"🔗 Fetched batch {start_ch} from {host}:{port}")
                return data.decode('utf-8')
        except (OSError, UnicodeDecodeError):
            # Narrow catch: a bare except would also swallow KeyboardInterrupt.
            continue
    return None

def get_batch(start_ch, ctx=None, queue=None, kernel=None):
    """Return the batch text starting at *start_ch*, or None if unavailable.

    Lookup order:
      1. this node's cache (local_batches);
      2. the peers, via fetch_batch_from_peers() (peer results are not cached);
      3. local generation, only when ``start_ch % CHANNELS`` is below
         NODE_CHANNELS. The generated batch is cached before returning.
    """
    if start_ch in local_batches:
        return local_batches[start_ch]

    from_peer = fetch_batch_from_peers(start_ch)
    if from_peer:
        return from_peer

    owned_here = (start_ch % CHANNELS) < NODE_CHANNELS
    if not owned_here:
        return None

    fresh = generate_batch(start_ch, BATCH_SIZE, ctx=ctx, queue=queue, kernel=kernel)
    local_batches[start_ch] = fresh
    return fresh

# -------------------------------
# OpenGL folding + per-channel movies
# -------------------------------
omega_time = 0.0  # global animation clock; display() adds 0.01 per shader-drawn frame
shader = None  # linked GL program from init_gl(), or None if compilation failed
yOffset = 0  # row offset passed to the shader; recomputed by idle() from current_channel
current_channel = 0  # zero-based index of the channel being displayed
frame_count = 0  # number of idle() calls so far; drives the auto-scroll cadence
auto_scroll = True  # when True, idle() advances to the next channel every 60 calls

# Pass-through vertex shader: forwards the 2D clip-space position and maps it
# from [-1, 1] to a [0, 1] texture coordinate for the fragment shader.
# NOTE(review): positions arrive through immediate-mode glVertex2f calls, which
# relies on a compatibility-profile context aliasing attribute 0 — confirm if
# the context type ever changes.
VERTEX_SRC = """
#version 330 core
layout(location=0) in vec2 pos;
out vec2 texCoord;
void main(){ texCoord=(pos+1.0)*0.5; gl_Position=vec4(pos,0,1); }
"""

# Fragment shader drawing the current channel's procedural pattern.
# Per pixel it:
#   * converts texCoord to lattice (x, y), shifted down by yOffset rows, and a
#     linear index idx = y * latticeWidth + x;
#   * sums 72 * 3 sine terms mixing channelHighlight, idx, omegaTime,
#     channelSeed and channelTime, weighted by phiPowers;
#   * wraps the sum into [0, 1) with fract and soft-thresholds it with
#     smoothstep over [threshold - 0.15, threshold + 0.15];
#   * tints the result with a fully saturated hue of
#     (channelHighlight / 1024) mod 1.
# The `threshold` uniform must be set by the host; a uniform that is never set
# stays at 0.0.
FRAGMENT_SRC = """
#version 330 core
in vec2 texCoord;
out vec4 fragColor;

uniform float omegaTime;
uniform float phiPowers[72];
uniform float threshold;
uniform int latticeWidth;
uniform int latticeHeight;
uniform int yOffset;
uniform int channelHighlight;
uniform float channelTime;
uniform float channelSeed;

vec3 hsv2rgb(float h,float s,float v){
    float c=v*s;
    float x=c*(1.0-abs(mod(h*6.0,2.0)-1.0));
    float m=v-c;
    vec3 col;
    if(h<1.0/6.0){col=vec3(c,x,0.0);}
    else if(h<2.0/6.0){col=vec3(x,c,0.0);}
    else if(h<3.0/6.0){col=vec3(0.0,c,x);}
    else if(h<4.0/6.0){col=vec3(0.0,x,c);}
    else if(h<5.0/6.0){col=vec3(x,0.0,c);}
    else{col=vec3(c,0.0,x);}
    return col+m;
}

void main(){
    int x=int(texCoord.x*float(latticeWidth));
    int y=int(texCoord.y*float(latticeHeight))+yOffset;
    int idx=y*latticeWidth+x;

    float val=0.0;
    for(int i=0;i<72;i++){
        // multi-sine per-channel animation
        for(int j=0;j<3;j++){
            val+=sin(float(channelHighlight+j)*phiPowers[i]*2.0
                     + float(idx+j)*phiPowers[71-i]*2.0
                     + omegaTime*0.15
                     + channelSeed*(10.0+j)
                     + channelTime*(6.28+j*3.14));
        }
    }
    val=fract(val*0.5+0.5);
    float slot = smoothstep(threshold-0.15, threshold+0.15, val);

    float hue=mod(float(channelHighlight)/1024.0,1.0);
    vec3 color=hsv2rgb(hue,1.0,1.0)*slot;
    fragColor=vec4(color,1.0);
}
"""

def init_gl():
    """Compile and link the lattice shaders into the module-level ``shader``.

    Any failure prints a warning and leaves ``shader`` as None, which makes
    display() skip the shader pass.
    """
    global shader
    try:
        linked = compileProgram(
            compileShader(VERTEX_SRC, GL_VERTEX_SHADER),
            compileShader(FRAGMENT_SRC, GL_FRAGMENT_SHADER),
        )
    except Exception as err:
        print(f"⚠️ Shader compilation failed: {err}")
        linked = None
    shader = linked

def display():
    """GLUT display callback: draw the current channel and a text overlay.

    When a shader program is available, it uploads every uniform the fragment
    shader reads. That includes ``threshold`` (from THRESHOLD); a uniform that
    is never set stays at 0.0, which would pin the smoothstep cut-off at 0. It
    then draws one oversized triangle covering the viewport and advances
    omega_time by 0.01. The overlay (1-based channel number and a frame count
    derived from the channel's clock) is drawn either way before buffers swap.
    """
    global omega_time
    glClear(GL_COLOR_BUFFER_BIT)
    if shader:
        glUseProgram(shader)
        glUniform1f(glGetUniformLocation(shader, "omegaTime"), omega_time)
        glUniform1f(glGetUniformLocation(shader, "threshold"), THRESHOLD)
        glUniform1i(glGetUniformLocation(shader, "yOffset"), yOffset)
        glUniform1i(glGetUniformLocation(shader, "latticeWidth"), LATTICE_WIDTH)
        glUniform1i(glGetUniformLocation(shader, "latticeHeight"), LATTICE_HEIGHT)
        glUniform1i(glGetUniformLocation(shader, "channelHighlight"), current_channel)
        glUniform1f(glGetUniformLocation(shader, "channelTime"), channel_times[current_channel])
        glUniform1f(glGetUniformLocation(shader, "channelSeed"), channel_seeds[current_channel])
        glUniform1fv(glGetUniformLocation(shader, "phiPowers"), len(PHI_POWERS), PHI_POWERS)
        omega_time += 0.01

        # A triangle through (-1,-1), (3,-1), (-1,3) covers the whole [-1, 1]
        # clip-space square, so every pixel runs the fragment shader.
        glBegin(GL_TRIANGLES)
        glVertex2f(-1, -1)
        glVertex2f(3, -1)
        glVertex2f(-1, 3)
        glEnd()
        glUseProgram(0)

    # Overlay: channel number + frame. The colour is latched when the raster
    # position is set, so glColor3f must come before glWindowPos2i.
    glColor3f(1.0, 1.0, 1.0)
    glWindowPos2i(10, 10)
    text = f"Channel: {current_channel+1}/{CHANNELS}  Frame: {int(channel_times[current_channel]*60)}"
    for c in text:
        glutBitmapCharacter(GLUT_BITMAP_HELVETICA_18, ord(c))

    glutSwapBuffers()

def idle():
    """GLUT idle callback: advance the visible channel's clock and request a redraw.

    Every call adds 0.016 * 5.0 to the current channel's time (a ~5x speed-up
    so motion is perceptible). With auto_scroll on, every 60th call moves to
    the next channel, wrapping at CHANNELS. yOffset is then recomputed from the
    channel index.
    """
    global yOffset, current_channel, frame_count
    frame_count += 1
    time_step = 0.016 * 5.0
    channel_times[current_channel] += time_step

    scroll_due = frame_count % 60 == 0
    if auto_scroll and scroll_due:
        current_channel = (current_channel + 1) % CHANNELS

    yOffset = (current_channel * CHUNK_HEIGHT) % LATTICE_HEIGHT
    glutPostRedisplay()

def keyboard(key, x, y):
    """GLUT keyboard callback.

    'w' / 's' step to the previous / next channel (wrapping at CHANNELS) and
    disable auto-scroll. 'a' toggles auto-scroll. Other keys are ignored.
    x and y are the mouse position supplied by GLUT and are unused.
    """
    global current_channel, auto_scroll
    step = {b'w': -1, b's': 1}.get(key)
    if step is not None:
        current_channel = (current_channel + step) % CHANNELS
        auto_scroll = False
    elif key == b'a':
        auto_scroll = not auto_scroll

# -------------------------------
# MAIN
# -------------------------------
if __name__ == "__main__":
    # Optional OpenCL setup; on failure ctx/queue/kernel stay None.
    ctx = queue = kernel = None
    if USE_OPENCL_HMAC:
        try:
            ctx, queue, kernel = init_opencl()
        except Exception as e:
            print(f"⚠️ OpenCL init failed: {e}")

    # Pre-generate every batch this node owns before the server starts serving.
    local_batches.update({
        first_channel: generate_batch(first_channel, BATCH_SIZE, ctx=ctx, queue=queue, kernel=kernel)
        for first_channel in range(0, NODE_CHANNELS, BATCH_SIZE)
    })
    print(f"✅ Pre-generated {len(local_batches)} secure batches locally")

    server_thread = threading.Thread(target=start_server, daemon=True)
    server_thread.start()
    print("🖥 Secure P2P node running, press Ctrl+C to exit.")

    # Window and GL setup, then hand control to GLUT's event loop.
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE)
    glutInitWindowSize(1280, 720)
    glutCreateWindow(b"HDGL P2P Node Secure v4")
    init_gl()
    for register, callback in (
        (glutDisplayFunc, display),
        (glutIdleFunc, idle),
        (glutKeyboardFunc, keyboard),
    ):
        register(callback)
    glutMainLoop()
