#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
tv10_ultrabatch_gl.py
HDGL recursive node with 131,072 channels,
ultra-batch OpenCL HMAC, streaming-safe Base4096 export,
and OpenGL real-time folding.
"""

import sys, math, struct, json, unicodedata, hmac, hashlib
import numpy as np
from base4096 import encode
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import compileShader, compileProgram
import pyopencl as cl
from concurrent.futures import ThreadPoolExecutor
import pygame
import scipy.io.wavfile as wavfile
import os
import tempfile
import time
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
from moviepy.editor import VideoFileClip

# -------------------------------
# CONFIG
# -------------------------------
# Lattice dimensions in cells; uploaded to the shader as latticeWidth/latticeHeight,
# and LATTICE_WIDTH is the default record count of export_binary_lattice().
LATTICE_WIDTH = 1920
LATTICE_HEIGHT = 1080
# Number of logical channels; 'w'/'s' and auto-scroll wrap modulo this value.
CHANNELS = 131_072
# Consecutive slot indices exported per channel (also the default number of
# root trees built by build_recursive_vectors()).
SAMPLES_PER_CHANNEL = 32
PHI = 1.6180339887  # golden ratio, truncated to 10 decimals
# PHI ** -(7 * (i + 1)) for i in 0..71, as float32 for the shader's phiPowers[72] uniform.
PHI_POWERS = np.array([1.0 / pow(PHI, 7*(i+1)) for i in range(72)], dtype=np.float32)
# Slot sum above which a lattice cell is lit (shader `threshold` uniform).
THRESHOLD = math.sqrt(PHI)
# Size of the slot index space (2**24). Channel ch starts at
# (ch * MAX_SLOTS) // CHANNELS, i.e. ch * 128 with the values above.
MAX_SLOTS = 16_777_216
# OpenCL toggle for channel export.
# NOTE: OPENCL_KERNEL is only a placeholder (XOR with 0xAA), not an HMAC implementation.
USE_OPENCL_HMAC = True
# Rows per channel band inside a group of 24 channels: yOffset = (ch % 24) * CHUNK_HEIGHT.
CHUNK_HEIGHT = LATTICE_HEIGHT // 24

# Output paths, relative to the current working directory.
EXPORT_JSON = "hdgl_vectors.json"
EXPORT_BINARY = "hdgl_lattice.hdgl"
EXPORT_BASE4096 = "vectors_ultrabatch.b4096"  # appended to; scanned at startup for "#Channel:" lines
# Key for the HMAC-SHA256 tags written next to each exported channel.
HMAC_KEY = b"ZCHG-Base4096-Signature-Key"

# -------------------------------
# CHAR SLOT HELPERS
# -------------------------------
def hdgl_char(idx):
    """Return the deterministic single-character glyph for slot *idx*.

    A multiplicative hash of *idx* is reduced into the Unicode code-point
    range 0..0x10FFFF. Results that land in the UTF-16 surrogate block
    (0xD800-0xDFFF) are shifted up by 0x800 into 0xE000-0xE7FF. The returned
    character is therefore always a valid Unicode scalar value that can be
    encoded as UTF-8.
    """
    code = (idx * 2654435761) % 0x110000
    if 0xD800 <= code <= 0xDFFF:
        # Move the whole surrogate block past 0xDFFF; stepping by 1 would stay
        # inside it for every value except 0xDFFF.
        code += 0x800
    return chr(code)

# -------------------------------
# RECURSIVE VECTORS
# -------------------------------
def unfold_slot(idx, depth=0):
    """Build the slot tree rooted at *idx*.

    Each node is a dict with keys ``idx``, ``value`` (a hash-derived float in
    [0, 1)), ``char`` (from ``hdgl_char``) and ``children``. Nodes at depth 0-2
    get two children whose indices are ``(idx * k + depth * 1337) % MAX_SLOTS``
    for k = 1 and 2. A tree started at depth 0 is therefore a complete binary
    tree four levels deep.
    """
    node = {
        "idx": idx,
        "value": (idx * 2654435761) % 4096 / 4096.0,
        "char": hdgl_char(idx),
        "children": [],
    }
    if depth >= 3:
        return node
    node["children"] = [
        unfold_slot((idx * multiplier + depth * 1337) % MAX_SLOTS, depth + 1)
        for multiplier in (1, 2)
    ]
    return node

def build_recursive_vectors(num_samples=SAMPLES_PER_CHANNEL):
    """Return a list with one ``unfold_slot`` tree per root index 0..num_samples-1."""
    return list(map(unfold_slot, range(num_samples)))

def flatten_indices_to_bytes(indices):
    """Pack each index as a 3-byte big-endian integer and concatenate the results.

    Every index is first reduced modulo 0x10FFFF. Values that land in the
    UTF-16 surrogate range 0xD800-0xDFFF are then incremented by one.

    NOTE(review): the increment only moves 0xDFFF out of the surrogate range,
    and the modulus excludes 0x10FFFF itself. Neither matters for the raw
    bytes produced here. Changing either would alter previously exported data.
    """
    def _pack(value):
        code = int(value) % 0x10FFFF
        if 0xD800 <= code <= 0xDFFF:
            code += 1
        return code.to_bytes(3, 'big')

    return b"".join(_pack(value) for value in indices)

# -------------------------------
# OPENCL HMAC
# -------------------------------
# Placeholder device kernel: each work-item gid writes in_data[gid] ^ 0xAA to
# out_hash[gid] (one work-item per byte). Despite its name it is not an HMAC:
# it uses no key and no hash function.
# NOTE(review): any "digest" taken from its output is only an obfuscated copy of the input.
OPENCL_KERNEL = """
__kernel void dummy_hmac(__global uchar *in_data, __global uchar *out_hash){
    int gid = get_global_id(0);
    out_hash[gid] = in_data[gid] ^ 0xAA;
}
"""

def init_opencl():
    """Set up OpenCL on the first device of the first platform and build OPENCL_KERNEL.

    Returns:
        ``(ctx, queue, kernel)``, where *kernel* is the ``dummy_hmac`` entry point.

    Errors from pyopencl propagate to the caller. An ``IndexError`` is raised
    when no platform or device exists.
    """
    first_device = cl.get_platforms()[0].get_devices()[0]
    context = cl.Context([first_device])
    # Tuple elements evaluate left to right: queue first, then program build.
    return (
        context,
        cl.CommandQueue(context),
        cl.Kernel(cl.Program(context, OPENCL_KERNEL).build(), "dummy_hmac"),
    )

# -------------------------------
# ULTRA-BATCH EXPORT
# -------------------------------
def _channel_payload(ch):
    """Return the packed slot indices exported for channel *ch*."""
    start = (ch * MAX_SLOTS) // CHANNELS
    return flatten_indices_to_bytes(np.arange(SAMPLES_PER_CHANNEL, dtype=np.uint32) + start)


def export_channels_ultrabatch(channel_list, out_file=EXPORT_BASE4096):
    """Append Base4096 payloads and HMAC-SHA256 tags for a batch of channels.

    For each channel ``ch`` the payload is the ``SAMPLES_PER_CHANNEL``
    consecutive slot indices starting at ``(ch * MAX_SLOTS) // CHANNELS``,
    packed by ``flatten_indices_to_bytes``. Three lines are appended to
    *out_file* per channel::

        #Channel:<ch>
        <base4096(payload)>
        #HMAC:<base4096(HMAC-SHA256(HMAC_KEY, payload))>

    Tags are always computed with ``hmac``/``hashlib``. ``OPENCL_KERNEL`` only
    XORs each byte with 0xAA and has no key, so its output must not be written
    under an ``#HMAC:`` label; ``USE_OPENCL_HMAC`` is therefore not consulted.

    Args:
        channel_list: iterable of channel numbers. An empty batch is a no-op.
        out_file: path of the Base4096 export file, opened in append mode.
    """
    channels = list(channel_list)
    if not channels:
        return
    print(f"🚀 Exporting batch of {len(channels)} channels...")

    payloads = [_channel_payload(ch) for ch in channels]
    # Payloads are ~100 bytes: hashing inline beats a thread pool, because
    # hashlib only releases the GIL for much larger buffers.
    digests = [hmac.new(HMAC_KEY, payload, hashlib.sha256).digest() for payload in payloads]

    with open(out_file, "a", encoding="utf-8") as f:
        for ch, payload, digest in zip(channels, payloads, digests):
            f.write(f"#Channel:{ch}\n")
            f.write(encode(payload) + "\n")
            f.write(f"#HMAC:{encode(digest)}\n")
    print(f"✅ Batch export complete for channels {channels}")

# -------------------------------
# JSON/BINARY EXPORTS
# -------------------------------
def export_recursive_vectors_json(vectors, outfile=EXPORT_JSON):
    """Write *vectors* to *outfile* as indented UTF-8 JSON with surrogates stripped.

    Lone UTF-16 surrogate code points cannot be encoded as UTF-8. Every string
    value anywhere in the nested dict/list structure therefore has them removed
    before dumping. Dictionary keys are left untouched.
    """
    def scrub(node):
        if isinstance(node, dict):
            return {key: scrub(val) for key, val in node.items()}
        if isinstance(node, list):
            return list(map(scrub, node))
        if isinstance(node, str):
            return "".join(ch for ch in node if not 0xD800 <= ord(ch) <= 0xDFFF)
        return node

    cleaned = scrub(vectors)
    with open(outfile, "w", encoding="utf-8") as fh:
        json.dump(cleaned, fh, ensure_ascii=False, indent=2)
    print(f"✅ Exported {len(vectors)} vectors to JSON: {outfile}")

def export_binary_lattice(num_samples=LATTICE_WIDTH, outfile=EXPORT_BINARY):
    """Write *num_samples* fixed-size lattice records to *outfile*.

    Each 8-byte record is ``(value: float32, idx: uint32)``, where *value* is
    ``(idx * 2654435761) % 4096 / 4096.0``. Records are packed little-endian
    with no padding (``"<fI"``), so the file layout does not depend on the
    host. On little-endian machines this is byte-identical to native ``"fI"``
    packing.
    """
    record = struct.Struct("<fI")
    payload = b"".join(
        record.pack((idx * 2654435761) % 4096 / 4096.0, idx) for idx in range(num_samples)
    )
    with open(outfile, "wb") as f:
        f.write(payload)
    print(f"✅ Exported {num_samples} lattice slots to binary: {outfile}")

# -------------------------------
# OPENGL REAL-TIME FOLDING
# -------------------------------
# Full-screen pass-through vertex shader: maps clip-space positions in [-1, 1]
# to texture coordinates in [0, 1].
# NOTE(review): display() feeds `pos` with immediate-mode glVertex2f, which
# relies on attribute-0 aliasing in a compatibility context. Confirm on
# core-only drivers.
VERTEX_SRC = """
#version 450 core
layout(location=0) in vec2 pos;
out vec2 texCoord;
void main(){ texCoord=(pos+1.0)*0.5; gl_Position=vec4(pos,0,1); }
"""

# Fragment shader. For each pixel:
# - The lattice cell is x = texCoord.x * latticeWidth, y = texCoord.y * latticeHeight + yOffset.
# - The cell index is hashed to `val` in [0, 1].
# - hdgl_slot() sums val, phiPowers[y % 72], an audio-driven resonance term, a
#   column-dependent wave and a time-varying term, and returns 1.0 when the sum
#   exceeds `threshold`, else 0.0.
# - The lit value scales channelColors[channelHighlight % 24].
# When hasVideo == 1, the videoTex brightness is added to the slot sum and
# scales the colour, and the result is mixed toward the video colour by `val`.
# Uniforms are set once in init_gl() and per frame in display().
FRAGMENT_SRC = """
#version 450 core
in vec2 texCoord;
out vec4 fragColor;
uniform float omegaTime;
uniform float phiPowers[72];
uniform float threshold;
uniform int latticeWidth;
uniform int latticeHeight;
uniform int yOffset;
uniform int channelHighlight;
uniform float audioLevel;
uniform sampler2D videoTex;
uniform int hasVideo;

vec3 channelColors[24] = vec3[24](
vec3(1,0,0),vec3(0,1,0),vec3(0,0,1),vec3(1,1,0),vec3(1,0,1),vec3(0,1,1),
vec3(0.5,0,0),vec3(0,0.5,0),vec3(0,0,0.5),vec3(0.5,0.5,0),vec3(0.5,0,0.5),vec3(0,0.5,0.5),
vec3(0.25,0,0),vec3(0,0.25,0),vec3(0,0,0.25),vec3(0.25,0.25,0),vec3(0.25,0,0.25),vec3(0,0.25,0.25),
vec3(0.75,0,0),vec3(0,0.75,0),vec3(0,0,0.75),vec3(0.75,0.75,0),vec3(0.75,0,0.75),vec3(0,0.75,0.75)
);

float hash_float(int i,int seed){ uint ui=uint(i*374761393 + seed*668265263u); return float(ui & 0xFFFFFFFFu)/4294967295.0; }
vec3 computeVectorColor(int idx,float slot){ return channelColors[channelHighlight%24]*slot; }
float hdgl_slot(float val,float r_dim,float omega,int x,int y,int idx){
float resonance=(x%4==0?0.1*sin(omegaTime*0.05+float(y)):0.0) + audioLevel * 0.2;
float wave=(x%3==0?0.3:(x%3==1?0.0:-0.3));
float omega_inst=phiPowers[y%72];
float rec=r_dim*val*0.5+0.25*sin(omegaTime*r_dim+float(x));
float new_val=val+omega_inst+resonance+wave+rec+omega*0.05;
if(hasVideo == 1){
  vec4 videoColor = texture(videoTex, texCoord);
  float videoBrightness = (videoColor.r + videoColor.g + videoColor.b) / 3.0;
  new_val += videoBrightness * 0.2;
}
return new_val>threshold?1.0:0.0;
}

void main(){
int x=int(texCoord.x*float(latticeWidth));
int y=int(texCoord.y*float(latticeHeight))+yOffset;
int idx=y*latticeWidth+x;
float val=hash_float(idx,12345);
float r_dim=0.3+0.01*float(y);
float slot=hdgl_slot(val,r_dim,sin(omegaTime),x,y,idx);
vec3 color=computeVectorColor(idx,slot);
vec3 finalColor = color.rgb;
if(hasVideo == 1){
  vec4 videoColor = texture(videoTex, texCoord);
  float videoBrightness = (videoColor.r + videoColor.g + videoColor.b) / 3.0;
  finalColor *= videoBrightness + 0.5; // video modulates lattice
  finalColor = mix(finalColor, videoColor.rgb, val); // lattice modulates video
}
fragColor=vec4(finalColor,1.0);
}
"""

# -------------------------------
# OPENGL CONTROL
# -------------------------------
# Mutable render/playback state shared by the GLUT callbacks and media helpers below.
omega_time = 0.0  # animation clock; display() adds 0.01 per frame
shader = None  # GL program handle, set by init_gl()
yOffset = 0  # row offset of the current channel band (shader uniform)
current_channel = 0  # selected channel in [0, CHANNELS)
previous_channel = -1  # channel update_music() last handled; -1 forces the first load
frame_count = 0  # idle() ticks; auto-scroll advances every 120
auto_scroll = True  # toggled with 'a'; cleared by 'w'/'s'
audio_level = 0.0  # |sample| / 32768 at the playback position (shader audioLevel)
current_rate = None  # sample rate of the playing media, or None
current_original_data = None  # unmodulated mono samples of the playing media, or None
media_files = []  # [{'path': ..., 'type': 'music' | 'movies'}], filled by main()
media_data_cache = {}  # media path -> (rate, samples) from load_audio_data()
exported_channels = set()  # channels already written to EXPORT_BASE4096
video_texture = 0  # GL texture name for video frames, set by init_gl()
current_clip = None  # VideoFileClip for the current movie, or None
current_clip_duration = 0.0  # clip length in seconds; 0.0 when there is no clip
current_frame = None  # latest frame from current_clip for display(), or None

def get_media_for_channel(ch):
    """Return the media path assigned to channel *ch*, or None when no media is loaded.

    Channels map onto ``media_files`` round-robin: ``ch % len(media_files)``.
    """
    if media_files:
        return media_files[ch % len(media_files)]['path']
    return None

def _to_int16_scale(samples):
    """Return *samples* as float64 on the int16 amplitude scale (full scale = ±32768).

    Signed integers of width b bits are scaled by 2**(16 - b). Unsigned
    integers (e.g. 8-bit WAV) are first re-centred on zero. Floats are assumed
    to be in [-1.0, 1.0].
    """
    samples = np.asarray(samples)
    if samples.dtype.kind == "f":
        return samples.astype(np.float64) * 32768.0
    bits = 8 * samples.dtype.itemsize
    scale = 2.0 ** (16 - bits)
    as_float = samples.astype(np.float64)
    if samples.dtype.kind == "u":
        as_float -= 2.0 ** (bits - 1)
    return as_float * scale


def load_audio_data(media_path):
    """Decode *media_path* to a mono float32 signal on the int16 amplitude scale.

    ``.wav`` files are read with ``scipy.io.wavfile``. Anything else goes
    through pydub (ffmpeg), retrying with an explicit ``m4a`` format if
    autodetection fails. Any number of channels is downmixed by averaging.
    Samples are rescaled so full scale is ±32768 whatever the source sample
    format, because callers clip to int16 and divide by 32768.0.

    Returns:
        ``(rate, data)``, where *data* is a 1-D float32 array, or
        ``(None, None)`` if loading fails. Errors are printed, not raised.
    """
    try:
        if media_path.lower().endswith('.wav'):
            rate, orig = wavfile.read(media_path)
            samples = _to_int16_scale(orig)
            if samples.ndim == 2:
                samples = samples.mean(axis=1)
            return rate, samples.astype(np.float32)
        try:
            audio = AudioSegment.from_file(media_path)
            print(f"ℹ️ Loaded audio from {media_path}: rate={audio.frame_rate}, channels={audio.channels}")
        except CouldntDecodeError:
            print(f"ℹ️ Load failed for {media_path}; retrying as M4A...")
            audio = AudioSegment.from_file(media_path, format='m4a')
            print(f"ℹ️ Loaded audio from {media_path} as M4A: rate={audio.frame_rate}, channels={audio.channels}")
        rate = audio.frame_rate
        # Samples are interleaved and typed by the segment's sample width.
        samples = _to_int16_scale(np.array(audio.get_array_of_samples()))
        if audio.channels > 1:
            samples = samples.reshape(-1, audio.channels).mean(axis=1)
        return rate, samples.astype(np.float32)
    except Exception as e:
        print(f"⚠️ Failed to load audio from {media_path}: {e}")
        return None, None

def ensure_exported(channel_group):
    """Export every channel in *channel_group* that is not yet in ``exported_channels``.

    Missing channels go out as one batch, and are then recorded in the shared
    set so later calls skip them.
    """
    pending = [ch for ch in channel_group if ch not in exported_channels]
    if not pending:
        return
    export_channels_ultrabatch(pending)
    exported_channels.update(pending)

# Temp WAV files written by update_music() that have not been deleted yet.
_modulated_wav_files = []


def _purge_modulated_wavs():
    """Best-effort removal of modulated temp WAVs that are no longer playing.

    Files that cannot be removed yet (e.g. still open on Windows) stay in
    ``_modulated_wav_files`` and are retried on the next call.
    """
    still_present = []
    for path in _modulated_wav_files:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            still_present.append(path)
    _modulated_wav_files[:] = still_present


def _close_current_clip():
    """Close the active VideoFileClip, releasing its reader, and reset the clip globals."""
    global current_clip, current_clip_duration
    if current_clip is not None:
        try:
            current_clip.close()
        except Exception as e:
            print(f"⚠️ Failed to close video clip: {e}")
    current_clip = None
    current_clip_duration = 0.0


def update_music():
    """Switch audio/video playback to the media assigned to ``current_channel``.

    Does nothing while the channel is unchanged. Otherwise:
    - Stops playback and releases the previous video clip and temp WAV files.
    - Decodes (and caches) the channel's audio.
    - Multiplies the audio by the channel's 32-value hash pattern, tiled over
      the whole signal.
    - Writes the result to a temp WAV and loops it with ``pygame.mixer.music``.

    For ``.avi``/``.mp4``/``.mpeg`` sources a VideoFileClip is also opened so
    ``idle()`` can pull frames. Failures are printed, not raised.
    """
    global previous_channel, current_rate, current_original_data, current_clip, current_clip_duration
    if current_channel == previous_channel:
        return
    pygame.mixer.music.stop()
    # pygame >= 2.0 can release the loaded file so its temp WAV becomes deletable.
    unload = getattr(pygame.mixer.music, "unload", None)
    if unload is not None:
        unload()
    # VideoFileClip keeps an ffmpeg reader open until close() is called.
    _close_current_clip()
    _purge_modulated_wavs()
    current_rate = None
    current_original_data = None

    media_path = get_media_for_channel(current_channel)
    if not media_path:
        previous_channel = current_channel
        return

    print(f"🎵 Playing media for channel {current_channel}: {media_path}")
    if media_path not in media_data_cache:
        rate, data = load_audio_data(media_path)
        if rate is None:
            previous_channel = current_channel
            return
        media_data_cache[media_path] = (rate, data)
    rate, original_data = media_data_cache[media_path]

    # Per-channel modulation pattern: the same hash as unfold_slot(). A uint32
    # wrap-around in the multiply does not change the value mod 4096, because
    # 4096 divides 2**32.
    indices = np.arange(SAMPLES_PER_CHANNEL, dtype=np.uint32) + (current_channel * MAX_SLOTS) // CHANNELS
    mod_values = (indices * 2654435761 % 4096) / 4096.0
    num_repeats = math.ceil(len(original_data) / len(mod_values))
    mod_signal = np.tile(mod_values, num_repeats)[:len(original_data)]
    data_mod = (original_data * mod_signal).clip(-32768, 32767).astype(np.int16)
    modulated_wav = os.path.join(tempfile.gettempdir(), f"modulated_{current_channel}_{int(time.time())}.wav")
    try:
        pygame.mixer.quit()
        pygame.mixer.init()
        # Track before writing so even a partially written file is cleaned up later.
        _modulated_wav_files.append(modulated_wav)
        wavfile.write(modulated_wav, rate, data_mod)
        pygame.mixer.music.load(modulated_wav)
        pygame.mixer.music.play(-1)
        current_rate = rate
        current_original_data = original_data
        if media_path.lower().endswith(('.avi', '.mp4', '.mpeg')):
            try:
                current_clip = VideoFileClip(media_path)
                current_clip_duration = current_clip.duration
                print(f"ℹ️ Loaded video from {media_path}: duration={current_clip_duration}s, size={current_clip.size}")
            except Exception as e:
                print(f"⚠️ Failed to load video from {media_path}: {e}")
                current_clip = None
                current_clip_duration = 0.0
    except PermissionError as e:
        print(f"⚠️ Failed to write {modulated_wav}: {e}")
    except Exception as e:
        print(f"⚠️ Unexpected error during media processing: {e}")
    previous_channel = current_channel

def display():
    """GLUT display callback: set per-frame uniforms, upload the video frame, draw.

    Draws one oversized triangle that covers the viewport, advances
    ``omega_time`` by 0.01 and swaps buffers.
    """
    global omega_time
    glClear(GL_COLOR_BUFFER_BIT)
    glUseProgram(shader)
    glUniform1f(glGetUniformLocation(shader, "omegaTime"), omega_time)
    glUniform1i(glGetUniformLocation(shader, "yOffset"), yOffset)
    glUniform1i(glGetUniformLocation(shader, "channelHighlight"), current_channel % 24)
    glUniform1f(glGetUniformLocation(shader, "audioLevel"), audio_level)
    has_video = current_frame is not None
    glUniform1i(glGetUniformLocation(shader, "hasVideo"), 1 if has_video else 0)
    if has_video:
        # Frames arrive top row first, but GL maps the first uploaded row to
        # t = 0, the bottom of the screen here. Flip so the video is upright.
        frame = np.ascontiguousarray(current_frame[::-1], dtype=np.uint8)
        glActiveTexture(GL_TEXTURE0)
        glBindTexture(GL_TEXTURE_2D, video_texture)
        # RGB rows are width * 3 bytes, which is not always a multiple of the
        # default 4-byte unpack alignment.
        glPixelStorei(GL_UNPACK_ALIGNMENT, 1)
        glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, frame.shape[1], frame.shape[0], 0,
                     GL_RGB, GL_UNSIGNED_BYTE, frame)
    omega_time += 0.01
    # Single triangle covering all of clip space [-1, 1]^2.
    glBegin(GL_TRIANGLES)
    glVertex2f(-1, -1)
    glVertex2f(3, -1)
    glVertex2f(-1, 3)
    glEnd()
    glutSwapBuffers()
    glBindTexture(GL_TEXTURE_2D, 0)

def idle():
    """GLUT idle callback: auto-scroll, keep exports/media in sync, sample audio and video.

    Auto-scroll advances one channel every 120 idle ticks. While music plays:
    - ``audio_level`` is the magnitude of the unmodulated sample at the
      playback position.
    - ``current_frame`` is the clip frame at that time.
    Both positions wrap, so they stay valid when the looped track restarts.
    """
    global yOffset, current_channel, frame_count, audio_level, current_frame
    frame_count += 1
    if auto_scroll and frame_count % 120 == 0:
        current_channel = (current_channel + 1) % CHANNELS
    group_start = (current_channel // 24) * 24
    channel_group = list(range(group_start, min(group_start + 24, CHANNELS)))
    ensure_exported(channel_group)
    update_music()
    yOffset = (current_channel % 24) * CHUNK_HEIGHT
    audio_level = 0.0
    current_frame = None
    if pygame.mixer.music.get_busy() and current_original_data is not None:
        # get_pos() counts from the last play() call, and the track is looped
        # with play(-1), so wrap the position back into the decoded signal.
        pos = pygame.mixer.music.get_pos() / 1000.0
        n_samples = len(current_original_data)
        if n_samples and current_rate:
            sample_pos = int(pos * current_rate) % n_samples
            audio_level = abs(current_original_data[sample_pos]) / 32768.0
        # A zero/None duration would make the modulo raise outside the try below.
        if current_clip is not None and current_clip_duration:
            pos_sec = pos % current_clip_duration
            try:
                current_frame = current_clip.get_frame(pos_sec)  # RGB numpy array
            except Exception as e:
                print(f"⚠️ Failed to get video frame at {pos_sec}s: {e}")
                current_frame = None
    glutPostRedisplay()

def keyboard(key, x, y):
    """GLUT keyboard callback.

    'w' steps to the previous channel and 's' to the next; both switch
    auto-scroll off. 'a' toggles auto-scroll. Every key press then re-syncs
    exports and media for the current 24-channel group and requests a redraw.
    """
    global current_channel, auto_scroll
    step = {b'w': -1, b's': 1}.get(key)
    if step is not None:
        current_channel = (current_channel + step) % CHANNELS
        auto_scroll = False
    elif key == b'a':
        auto_scroll = not auto_scroll
    first = (current_channel // 24) * 24
    ensure_exported(list(range(first, min(first + 24, CHANNELS))))
    update_music()
    glutPostRedisplay()

def init_gl():
    """Compile the folding shader, upload its constant uniforms, and create the video texture.

    Requires a current GL context; main() calls it right after glutCreateWindow.
    Sets the module globals ``shader`` and ``video_texture``.
    """
    global shader, video_texture
    vs = compileShader(VERTEX_SRC, GL_VERTEX_SHADER)
    fs = compileShader(FRAGMENT_SRC, GL_FRAGMENT_SHADER)
    shader = compileProgram(vs, fs)
    glUseProgram(shader)
    glUniform1fv(glGetUniformLocation(shader, "phiPowers"), len(PHI_POWERS), PHI_POWERS)
    glUniform1f(glGetUniformLocation(shader, "threshold"), THRESHOLD)
    glUniform1i(glGetUniformLocation(shader, "latticeWidth"), LATTICE_WIDTH)
    glUniform1i(glGetUniformLocation(shader, "latticeHeight"), LATTICE_HEIGHT)
    glUniform1f(glGetUniformLocation(shader, "audioLevel"), 0.0)
    glUniform1i(glGetUniformLocation(shader, "videoTex"), 0)  # sampler reads texture unit 0
    glUniform1i(glGetUniformLocation(shader, "hasVideo"), 0)
    video_texture = glGenTextures(1)
    # The target enum is GL_TEXTURE_2D; there is no GL_TEXTURE2D in OpenGL.GL.
    glBindTexture(GL_TEXTURE_2D, video_texture)
    # No mipmaps are ever uploaded, so the minification filter must not use them.
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR)
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR)
    # Clamp so linear filtering at the frame border does not blend in the opposite edge.
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)
    glBindTexture(GL_TEXTURE_2D, 0)

# -------------------------------
# MAIN
# -------------------------------
def _read_exported_channel_ids(path):
    """Return the channel numbers named on ``#Channel:`` lines of *path*.

    Returns an empty set when the file does not exist. Lines whose channel
    number does not parse as an int are ignored.
    """
    found = set()
    if not os.path.exists(path):
        return found
    prefix = "#Channel:"
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.startswith(prefix):
                continue
            try:
                found.add(int(line[len(prefix):].strip()))
            except ValueError:
                pass
    return found


def _discover_media_files():
    """List media under ./music (.wav/.mp3) then ./movies (.avi/.mp4/.mpeg), sorted per folder."""
    found = []
    for folder, extensions in (('music', ('.wav', '.mp3')), ('movies', ('.avi', '.mp4', '.mpeg'))):
        if not os.path.exists(folder):
            print(f"⚠️ Directory '{folder}' not found.")
            continue
        names = sorted(name for name in os.listdir(folder) if name.lower().endswith(extensions))
        found += [{'path': os.path.join(folder, name), 'type': folder} for name in names]
        print(f"Loaded {len(names)} {folder} files from '{folder}': {names}")
    return found


def main():
    """Run the node.

    Steps, in order:
    - Write the static JSON/binary exports.
    - Restore the set of already-exported channels.
    - Discover media files.
    - Open the GLUT window and set up pygame audio.
    - Start the render loop.
    """
    global media_files
    print(f"🚀 Starting HDGL node ultra-batch with {CHANNELS} channels...")
    export_recursive_vectors_json(build_recursive_vectors())
    export_binary_lattice()

    # Channels already present in the Base4096 file are not exported again.
    exported_channels.update(_read_exported_channel_ids(EXPORT_BASE4096))
    print(f"Loaded {len(exported_channels)} pre-exported channels.")

    media_files = _discover_media_files()
    if not media_files:
        print("⚠️ No media files found in 'music' or 'movies'. No audio/video will be played.")

    # The GL context must exist before init_gl() compiles shaders.
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE)
    glutInitWindowSize(1280, 720)
    glutCreateWindow(b"HDGL Streaming Node - Ultra-Batch OpenCL + OpenGL")
    init_gl()
    glutDisplayFunc(display)
    glutIdleFunc(idle)
    glutKeyboardFunc(keyboard)

    # pygame is used only for audio playback.
    pygame.init()
    pygame.mixer.init()

    # Prime the first channel group before the loop starts.
    ensure_exported(list(range(min(24, CHANNELS))))
    update_music()

    print("🖥 OpenGL folding running... (w/s to scroll, a to toggle auto-scroll)")
    glutMainLoop()

# Script entry point; importing the module only defines functions and state.
if __name__ == "__main__":
    main()
