import genesis as gs
import torch # GPU 백엔드 필수 (제거 금지)
import numpy as np
import math
import tkinter as tk
from tkinter import ttk, scrolledtext, messagebox, filedialog # filedialog 추가
import traceback
import sys
import time # 타이머용
import json # 저장용
import os

# ===========================================================
# 1. Simulation setup
# ===========================================================
# Initialize Genesis on the GPU backend.
gs.init(backend=gs.gpu)

# Scene: dt = 0.01 s with 10 substeps per step. The collision-pair budget is
# raised to 2000 — presumably to fit two arms plus the many marker entities
# (TODO confirm the value actually needed).
scene = gs.Scene(
    sim_options=gs.options.SimOptions(dt=0.01, substeps=10),
    rigid_options=gs.options.RigidOptions(
        max_collision_pairs=2000
    ),
    viewer_options=gs.options.ViewerOptions(
        camera_pos=(2.0, -2.0, 1.5), 
        camera_lookat=(0.0, 0.0, 0.5)
    ),
    show_viewer=True
)

# Ground plane.
plane = scene.add_entity(gs.morphs.Plane())

# The six UR5e arm joints; the robot-init loop resolves their local DOF
# indices in this order.
JOINT_NAMES = [
    "shoulder_pan_joint", "shoulder_lift_joint", "elbow_joint",
    "wrist_1_joint", "wrist_2_joint", "wrist_3_joint"
]

# ===========================================================
# 2. Robot and marker system
# ===========================================================
class RobotState:
    """Runtime state of one arm, plus the Genesis entity that renders it.

    Constructing an instance adds a light-grey, 30%-opaque UR5e MJCF entity
    to the module-level ``scene`` at ``pos``.
    """

    def __init__(self, idx, pos, name):
        # Identity.
        self.idx, self.name = idx, name

        # Visual + physical entity.
        self.surface = gs.surfaces.Default(color=(0.8, 0.8, 0.8), opacity=0.3)
        arm_morph = gs.morphs.MJCF(file="xml/universal_robots_ur5e/ur5e.xml", pos=pos)
        self.entity = scene.add_entity(arm_morph, surface=self.surface)

        # Mission / control bookkeeping (most of it is reset by start_mission()).
        self.active = False
        self.state = "NORMAL"
        self.progress_idx = 0
        self.last_touched_idx = -1
        self.touched_history = []
        self.stuck_timer = 0
        self.target_joints = np.zeros(6)
        self.dofs_idx = []

        # Assigned later by the module-level setup code.
        self.tcp_marker_entity = None
        self.tcp_link_obj = None

# The two arms, mirrored 0.6 either side of the origin along x.
robots = [
    RobotState(0, (-0.6, 0, 0), "Robot 1"),
    RobotState(1, (0.6, 0, 0), "Robot 2")
]

# Joint-space home configuration, one value per arm joint (presumably
# radians: 1.57 ~ pi/2).
home_pose = np.array([0, -1.57, 1.57, -1.57, -1.57, 0], dtype=float)

# --- Marker system ---
# Number of target points in a trajectory.
N_POINTS = 20
trajectory_points = []
# Per-point touch state: 0 = untouched, 1 = touched by Robot 1,
# 2 = touched by Robot 2.
point_states = [0] * N_POINTS 

# One sphere per trajectory point in each colour: yellow = pending target,
# blue = touched by Robot 1, red = touched by Robot 2.
markers_yellow = []
markers_blue = []
markers_red = []
# Green spheres that follow each robot's TCP (one per robot).
tcp_markers = []

def get_random_pos():
    """Sample a random target point inside the shared workspace box.

    Each coordinate is drawn independently and uniformly, in x, y, z order:
    x in [-0.15, 0.15), y in [-0.3, 0.3), z in [0.15, 0.5).

    Returns:
        np.ndarray: a length-3 float array.
    """
    axis_bounds = ((-0.15, 0.15), (-0.3, 0.3), (0.15, 0.5))
    # One scalar draw per axis, in axis order, so the global RNG stream is
    # consumed exactly the same way as three separate calls.
    return np.array([np.random.uniform(low, high) for low, high in axis_bounds])

def get_hidden_pos(index, type_offset):
    """Return an off-screen parking spot far below the ground plane.

    Markers are "hidden" by moving them here. Each marker category
    (``type_offset``) gets its own 5-unit band below z = -10, and each
    ``index`` its own 0.1-unit slot within that band.

    Returns:
        np.ndarray: float array ``[0, 0, z]``.
    """
    depth_below = 10.0 + index * 0.1 + type_offset * 5.0
    return np.array([0, 0, -depth_below])

# Create the marker spheres (all fixed=True, collision=False). Each trajectory
# point gets a visible yellow marker plus blue and red markers that start
# parked below the floor; simulation_step moves them into view on touch.
for i in range(N_POINTS):
    pos = get_random_pos()
    trajectory_points.append(pos)
    
    # Yellow: pending target, shown at the point itself.
    my = scene.add_entity(gs.morphs.Sphere(radius=0.02, pos=pos, fixed=True, collision=False),
                          surface=gs.surfaces.Default(color=(1.0, 1.0, 0.0)))
    markers_yellow.append(my)
    
    # Blue: "touched by Robot 1" (slightly larger than yellow), hidden for now.
    mb = scene.add_entity(gs.morphs.Sphere(radius=0.021, pos=get_hidden_pos(i, 0), fixed=True, collision=False),
                          surface=gs.surfaces.Default(color=(0.0, 0.0, 1.0)))
    markers_blue.append(mb)
    
    # Red: "touched by Robot 2", hidden for now.
    mr = scene.add_entity(gs.morphs.Sphere(radius=0.021, pos=get_hidden_pos(i, 1), fixed=True, collision=False),
                          surface=gs.surfaces.Default(color=(1.0, 0.0, 0.0)))
    markers_red.append(mr)

# One green marker per robot; simulation_step moves it onto that robot's TCP.
for i in range(2):
    m = scene.add_entity(gs.morphs.Sphere(radius=0.015, pos=get_hidden_pos(i, 2), fixed=True, collision=False),
                         surface=gs.surfaces.Default(color=(0.0, 1.0, 0.0)))
    tcp_markers.append(m)

robots[0].tcp_marker_entity = tcp_markers[0]
robots[1].tcp_marker_entity = tcp_markers[1]

# Finalize the scene (all entities above must be added before this call —
# presumably a Genesis requirement; confirm before adding entities later).
scene.build()

# --- Robot initialization (runs after scene.build()) ---
for r in robots:
    # Resolve each arm joint to its local DOF index. A failed lookup is
    # reported instead of silently ignored, because it triggers the DOF 0-5
    # fallback below and that fallback is only correct if the arm's joints
    # really are the entity's first six DOFs.
    found_idxs = []
    for j_name in JOINT_NAMES:
        try:
            joint = r.entity.get_joint(j_name)
            if joint is not None: found_idxs.append(joint.dof_idx_local)
        except Exception as e:
            print(f"⚠️ [{r.name}] 관절 조회 실패 '{j_name}': {e}", flush=True)

    if len(found_idxs) == 6:
        r.dofs_idx = found_idxs
    else:
        print(f"⚠️ [{r.name}] 관절 {len(found_idxs)}/6개만 발견 → DOF 0-5로 대체", flush=True)
        r.dofs_idx = [0, 1, 2, 3, 4, 5]

    # TCP link: prefer the virtual end-effector link, otherwise fall back to
    # the last wrist link.
    try:
        r.tcp_link_obj = r.entity.get_link("ee_virtual_link")
    except Exception:
        r.tcp_link_obj = r.entity.get_link("wrist_3_link")

    r.entity.set_qpos(torch.tensor(home_pose, device=gs.device))
    r.target_joints = np.copy(home_pose)

    # Per the original author's note: on the GPU (CUDA) backend, passing NumPy
    # arrays directly can stall the simulation, so gains and force limits are
    # passed as torch tensors on gs.device.
    r.entity.set_dofs_kp(kp=torch.tensor([10000]*6, device=gs.device), dofs_idx_local=r.dofs_idx)
    r.entity.set_dofs_kv(kv=torch.tensor([1000]*6, device=gs.device), dofs_idx_local=r.dofs_idx)
    r.entity.set_dofs_force_range(
        lower=torch.tensor([-500]*6, device=gs.device), 
        upper=torch.tensor([500]*6, device=gs.device), 
        dofs_idx_local=r.dofs_idx
    )

# ===========================================================
# 3. [Improved] Robust hybrid IK (with safety guards)
# ===========================================================
def to_numpy(x):
    """Return ``x`` as a NumPy array.

    torch tensors (on any device) are detached from autograd and copied to
    host memory; anything else is passed through ``np.array`` (which copies).
    """
    import torch

    if not isinstance(x, torch.Tensor):
        return np.array(x)
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

def normalize_vec(v):
    """Return ``v`` scaled to unit Euclidean length.

    Vectors with norm below 1e-6 are returned unchanged (the same object)
    rather than divided by a near-zero length.
    """
    length = np.linalg.norm(v)
    return v if length < 1e-6 else v / length

def quat_to_matrix(q):
    """Convert a quaternion in (w, x, y, z) order into a 3x3 rotation matrix.

    The quaternion is not normalized here; pass a unit quaternion to get a
    proper rotation matrix.
    """
    w, x, y, z = q
    first_row = [1 - 2*y**2 - 2*z**2, 2*x*y - 2*z*w, 2*x*z + 2*y*w]
    second_row = [2*x*y + 2*z*w, 1 - 2*x**2 - 2*z**2, 2*y*z - 2*x*w]
    third_row = [2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x**2 - 2*y**2]
    return np.array([first_row, second_row, third_row])

def calculate_hybrid_ik(robot, target_pos, other_tcp_pos=None):
    """Compute one joint-space step: track a position, optionally face the other arm.

    Primary task: move the TCP toward ``target_pos`` with the pseudo-inverse
    of the positional Jacobian. Secondary task (optional): turn the TCP's
    local z axis toward ``other_tcp_pos``. This step is projected into the
    null space of the primary task so it does not disturb position tracking.

    Numerical problems in the solve do not raise:
    - a pseudo-inverse failure returns a zero step;
    - a failure in the secondary task keeps the position-only step;
    - any NaN/Inf in the result is replaced by zeros.

    Args:
        robot: RobotState whose ``tcp_link_obj`` and ``entity`` are used.
        target_pos: desired TCP position, length-3 array (same frame as the
            link's ``get_pos()``).
        other_tcp_pos: the other arm's TCP position, or None to skip the
            facing task.

    Returns:
        (dq, dist, curr_pos): a joint step with one entry per Jacobian column,
        clipped to [-5, 5]; the TCP-to-target distance measured before the
        step; and the current TCP position.
    """
    tcp_link = robot.tcp_link_obj

    # 1. Current TCP pose (RigidLink exposes get_pos() / get_quat()).
    curr_pos = to_numpy(tcp_link.get_pos())
    curr_rot = quat_to_matrix(to_numpy(tcp_link.get_quat()))

    # 2. Jacobian: rows 0-2 linear, rows 3-5 angular.
    J_all = to_numpy(robot.entity.get_jacobian(tcp_link))
    J_pos = J_all[:3, :]
    J_rot = J_all[3:, :]
    n_dof = J_all.shape[1]

    # 3. [Task 1] Position tracking.
    error_pos = target_pos - curr_pos
    dist = np.linalg.norm(error_pos)

    # Truncated-SVD pseudo-inverse: singular values below 1e-3 x the largest
    # are treated as zero, which keeps the step bounded near singularities.
    # (This is truncation, not damped least squares.)
    try:
        J_pos_pinv = np.linalg.pinv(J_pos, rcond=1e-3)
    except np.linalg.LinAlgError:
        # SVD did not converge (e.g. NaN in the Jacobian): stand still.
        return np.zeros(n_dof), dist, curr_pos
    dq_total = J_pos_pinv @ error_pos

    # 4. [Task 2] Face the other arm, only within a useful distance band.
    if other_tcp_pos is not None:
        to_other = other_tcp_pos - curr_pos
        other_dist = np.linalg.norm(to_other)

        # Skip when the arms are too close (< 0.05) or too far apart (> 0.60).
        if 0.05 < other_dist < 0.60:
            try:
                current_z = normalize_vec(curr_rot[:3, 2])  # current approach axis
                target_z = normalize_vec(to_other)           # direction to the other TCP
                # Rotating about current_z x target_z turns current_z toward target_z.
                rot_error = np.cross(current_z, target_z)

                # Projector onto the null space of the position task.
                P_null = np.eye(n_dof) - J_pos_pinv @ J_pos

                # Jacobian-transpose orientation step (gain 1.5), projected so
                # it leaves the position step unaffected.
                dq_look = J_rot.T @ (rot_error * 1.5)
                dq_total = dq_total + P_null @ dq_look
            except Exception:
                # Best effort: keep the position-only step.
                pass

    # 5. Safety: drop non-finite steps, then clip each joint's step.
    if not np.all(np.isfinite(dq_total)):
        print("⚠️ IK Warning: NaN detected, stopping robot.")
        dq_total = np.zeros(n_dof)

    dq_total = np.clip(dq_total, -5.0, 5.0)

    return dq_total, dist, curr_pos

# ===========================================================
# 4. GUI panel
# ===========================================================
# Main control window.
root = tk.Tk()
root.title("UR5e 통합 제어 (거리 조절 기능)")
root.geometry("400x800")

style = ttk.Style()
style.configure("Bold.TLabel", font=("Arial", 10, "bold"))

# [A] Log panel: the text box is created 'disabled' (read-only); log()
# re-enables it only while appending a line.
log_frame = ttk.LabelFrame(root, text=" 📜 시스템 로그 ", padding=(10, 5))
log_frame.pack(fill="x", padx=10, pady=5)
log_box = scrolledtext.ScrolledText(log_frame, height=8, state='disabled', font=("Consolas", 9))
log_box.pack(fill="x")

def log(msg):
    """Print ``msg`` to the console and append it to the GUI log panel.

    The panel is unlocked only for the insert and locked again afterwards,
    and it auto-scrolls so the newest line stays visible.
    """
    print(msg, flush=True)  # mirror every message to stdout
    if not log_box:
        return
    log_box.configure(state='normal')
    log_box.insert(tk.END, msg + "\n")
    log_box.see(tk.END)
    log_box.configure(state='disabled')

# [B] Mission-control panel.
ctrl_frame = ttk.LabelFrame(root, text=" 🎮 미션 제어 ", padding=(10, 5))
ctrl_frame.pack(fill="x", padx=10, pady=5)

# TCP-to-TCP distance below which Robot 2 retreats; read every frame by
# simulation_step and bound to the slider below.
safety_dist_var = tk.DoubleVar(value=0.20) # default 20 cm
def update_safety_label(val):
    """Slider callback: show the newly selected collision distance.

    Tk passes the slider value as a string; it is shown with two decimals.
    """
    meters = float(val)
    safety_lbl.configure(text=f"🛡️ 충돌 감지 거리: {meters:.2f} m")

# Label + slider (0.10-0.50, step 0.01) for the collision-warning distance;
# moving the slider updates safety_dist_var and refreshes the label.
safety_lbl = ttk.Label(ctrl_frame, text=f"🛡️ 충돌 감지 거리: {safety_dist_var.get():.2f} m")
safety_lbl.pack(anchor='w', pady=(5,0))
safety_slider = tk.Scale(
    ctrl_frame, 
    from_=0.10, to=0.50, 
    resolution=0.01, 
    orient='horizontal', 
    variable=safety_dist_var,
    command=update_safety_label
)
safety_slider.pack(fill='x', pady=5)

# [C] Info panel: touched-point history per robot and mission elapsed time.
info_frame = ttk.LabelFrame(root, text=" 📊 터치 기록 ", padding=(10, 5))
info_frame.pack(fill="x", padx=10, pady=5)
r1_history_var = tk.StringVar(value="R1: []")
r2_history_var = tk.StringVar(value="R2: []")
mission_time_var = tk.StringVar(value="⏱️ 소요 시간: 0.00s") # added

ttk.Label(info_frame, textvariable=r1_history_var, foreground="blue", wraplength=350).pack(anchor='w', pady=2)
ttk.Label(info_frame, textvariable=r2_history_var, foreground="red", wraplength=350).pack(anchor='w', pady=2)
ttk.Label(info_frame, textvariable=mission_time_var, font=("Arial", 10, "bold")).pack(anchor='w', pady=5) # shows the elapsed time in the UI

# ===========================================================
# 5. Simulation logic
# ===========================================================
trajectory_running = False  # True while a mission is in progress
mission_start_time = 0 # added: time.time() at mission start, for the elapsed-time display
speed_scale = 3.0  # multiplier on the IK step before it is added to target_joints
TIMEOUT_FRAMES = 300  # frames without touching the current target before it is skipped
TOUCH_TOLERANCE = 0.04  # TCP-to-target distance that counts as a touch
SAFETY_DIST_TARGET = 0.25  # Robot 2 waits if its next target is this close to Robot 1's TCP
debug_timer = 0  # mission-step counter; per-robot status is logged every 60 steps

# Max per-joint gap (rad) between a robot's joint target and home_pose below
# which a robot that has handled all of its points counts as "back home".
HOME_TOLERANCE = 1e-2


def _update_tcp_markers():
    """Move each robot's green marker onto its TCP and return the TCP positions.

    If a robot's TCP cannot be read (or its marker cannot be moved), the
    origin is reported for it, so the result always has one entry per robot.
    """
    tcp_positions = []
    for r in robots:
        try:
            p = to_numpy(r.tcp_link_obj.get_pos())
            r.tcp_marker_entity.set_pos(p)
            tcp_positions.append(p)
        except Exception:
            tcp_positions.append(np.array([0,0,0]))
    return tcp_positions


def _update_r2_avoidance(tcp1, robot_dist):
    """Set Robot 2's state to RETREAT / WAIT / NORMAL; Robot 2 always yields to Robot 1."""
    r1, r2 = robots[0], robots[1]
    if not (r1.active and r2.active):
        return

    # Robot 2 walks the trajectory in reverse order.
    if r2.progress_idx < N_POINTS:
        target2_pos = trajectory_points[N_POINTS - 1 - r2.progress_idx]
    else:
        target2_pos = None

    if robot_dist < safety_dist_var.get():
        # TCPs closer than the GUI-selected safety distance: back off.
        if r2.state != "RETREAT": log(f"🚨 충돌 위험! R2 방어자세 ({robot_dist*100:.0f}cm)")
        r2.state = "RETREAT"
    elif target2_pos is not None and np.linalg.norm(target2_pos - tcp1) < SAFETY_DIST_TARGET:
        # Robot 2's next target is inside Robot 1's current work area: hold.
        if r2.state != "WAIT": log("✋ R1 작업 중. R2 대기.")
        r2.state = "WAIT"
    else:
        if r2.state in ["RETREAT", "WAIT"]: log("✅ 경로 확보. R2 진입.")
        r2.state = "NORMAL"


def simulation_step():
    """Advance the two-robot mission controller by one frame.

    Called once per GUI tick, right after ``scene.step()``. Always refreshes
    the TCP markers. While ``trajectory_running`` is False it only holds every
    arm at its current joint target.

    During a mission it:
    - updates Robot 2's avoidance state;
    - steps each active robot toward its next point with ``calculate_hybrid_ik``;
    - records touches and recolours the markers;
    - skips points not reached within ``TIMEOUT_FRAMES`` frames.

    The mission ends once every robot has handled all points AND its joint
    target is back within ``HOME_TOLERANCE`` of ``home_pose``. The reported
    elapsed time therefore includes the return-home motion.
    """
    global debug_timer, trajectory_running

    tcp_positions = _update_tcp_markers()

    if not trajectory_running:
        for r in robots: r.entity.control_dofs_position(r.target_joints, r.dofs_idx)
        return

    tcp1, tcp2 = tcp_positions[0], tcp_positions[1]
    _update_r2_avoidance(tcp1, np.linalg.norm(tcp1 - tcp2))

    debug_timer += 1
    finished_count = 0

    for i, r_state in enumerate(robots):
        if not r_state.active:
            r_state.entity.control_dofs_position(r_state.target_joints, r_state.dofs_idx)
            continue

        # All points handled: blend the joint target 5% per frame toward home.
        # Count the robot as finished only once that blend has converged.
        # Ending the mission stops this branch from running, so counting
        # earlier would freeze the arm part-way home.
        if r_state.progress_idx >= N_POINTS:
            r_state.target_joints += (home_pose - r_state.target_joints) * 0.05
            r_state.entity.control_dofs_position(r_state.target_joints, r_state.dofs_idx)
            if np.max(np.abs(home_pose - r_state.target_joints)) < HOME_TOLERANCE:
                finished_count += 1
            continue

        if r_state.state == "RETREAT":
            # Back off toward home until the other arm is far enough away.
            r_state.target_joints += (home_pose - r_state.target_joints) * 0.05
            r_state.entity.control_dofs_position(r_state.target_joints, r_state.dofs_idx)
            continue

        if r_state.state == "WAIT":
            r_state.entity.control_dofs_position(r_state.target_joints, r_state.dofs_idx)
            continue

        # Robot 1 visits points in order, Robot 2 in reverse order.
        if i == 0: my_target_idx = r_state.progress_idx
        else: my_target_idx = N_POINTS - 1 - r_state.progress_idx

        target = trajectory_points[my_target_idx]
        other_tcp = tcp2 if i == 0 else tcp1

        dq, dist, curr_pos = calculate_hybrid_ik(r_state, target, other_tcp_pos=other_tcp)

        # Integrate the IK step into the joint target (the 0.01 factor
        # presumably mirrors the simulation dt).
        r_state.target_joints += dq * speed_scale * 0.01
        r_state.entity.control_dofs_position(r_state.target_joints, r_state.dofs_idx)

        # Touch check: record the point once, swap its yellow marker for the
        # robot's colour, and advance to the next point.
        if dist < TOUCH_TOLERANCE:
            if r_state.last_touched_idx != my_target_idx:
                r_state.last_touched_idx = my_target_idx
                r_state.touched_history.append(my_target_idx)

                if i == 0: r1_history_var.set(f"R1: {r_state.touched_history}")
                else: r2_history_var.set(f"R2: {r_state.touched_history}")

                markers_yellow[my_target_idx].set_pos(get_hidden_pos(my_target_idx, 0))
                if i == 0:
                    markers_blue[my_target_idx].set_pos(trajectory_points[my_target_idx])
                    point_states[my_target_idx] = 1
                else:
                    markers_red[my_target_idx].set_pos(trajectory_points[my_target_idx])
                    point_states[my_target_idx] = 2

                r_state.stuck_timer = 0
                r_state.progress_idx += 1
                if r_state.progress_idx >= N_POINTS:
                    log(f"🎉 {r_state.name} 완료! 복귀.")
        else:
            if r_state.state == "NORMAL":
                r_state.stuck_timer += 1
                if r_state.stuck_timer > TIMEOUT_FRAMES:
                    log(f"⚠️ [{r_state.name}] 스킵 #{my_target_idx}")
                    r_state.progress_idx += 1
                    r_state.stuck_timer = 0
                    if r_state.progress_idx >= N_POINTS:
                        # Skipping the last point also completes this robot's run.
                        log(f"🎉 {r_state.name} 완료! 복귀.")

        if debug_timer % 60 == 0:
            icon = "🛡️" if r_state.state == "RETREAT" else ("✋" if r_state.state=="WAIT" else "🚀")
            log(f"[{r_state.name}] {icon} 진행:{r_state.progress_idx}/{N_POINTS} | 거리:{dist*100:.1f}cm")

    if finished_count >= len(robots):
        elapsed = time.time() - mission_start_time
        msg = f"⏱️ 모든 미션 완료! 소요 시간: {elapsed:.2f}초"
        log(f"🎉🎉 {msg}")
        mission_time_var.set(msg)  # update the UI label
        trajectory_running = False

# ===========================================================
# 6. Run control
# ===========================================================
def reset_markers_visibility():
    """Mark every trajectory point as untouched again.

    Clears ``point_states``, puts each yellow marker back on its trajectory
    point, and parks every blue and red marker off-screen.
    """
    global point_states
    point_states = [0 for _ in range(N_POINTS)]
    for idx in range(N_POINTS):
        markers_yellow[idx].set_pos(trajectory_points[idx])
        parked_blue = get_hidden_pos(idx, 0)
        parked_red = get_hidden_pos(idx, 1)
        markers_blue[idx].set_pos(parked_blue)
        markers_red[idx].set_pos(parked_red)

def regenerate_markers(reset_history=True):
    """Replace every trajectory point with a fresh random one and reset the markers.

    Args:
        reset_history: despite its name, this only controls whether a
            confirmation line is logged; robot progress and touch histories
            are not modified here.
    """
    global trajectory_points
    trajectory_points = [get_random_pos() for _ in range(N_POINTS)]
    reset_markers_visibility()
    if reset_history:
        log("🔄 궤적 재생성 완료")

def start_mission():
    """Reset both robots to home and start a mission on the current points.

    The current trajectory (random or loaded from file) is reused as-is;
    only per-robot progress, the history labels and the marker colours are
    reset. Also starts the elapsed-time clock.
    """
    global trajectory_running, mission_start_time
    log("▶ 미션 초기화 및 시작...")
    mission_start_time = time.time()

    for robot in robots:
        # Put the arm at the home pose with zero joint velocity.
        robot.target_joints = np.copy(home_pose)
        robot.entity.set_qpos(torch.tensor(home_pose, device=gs.device))
        robot.entity.set_dofs_velocity(torch.zeros(6, device=gs.device), robot.dofs_idx)
        # Fresh mission bookkeeping.
        robot.active = True
        robot.state = "NORMAL"
        robot.progress_idx = 0
        robot.stuck_timer = 0
        robot.last_touched_idx = -1
        robot.touched_history = []

    label_resets = (
        (r1_history_var, "R1: []"),
        (r2_history_var, "R2: []"),
        (mission_time_var, "⏱️ 소요 시간: 측정 중..."),
    )
    for ui_var, text in label_resets:
        ui_var.set(text)

    reset_markers_visibility()
    trajectory_running = True
    log("🚀 미션 시작! (타이머 작동)")

def test_wiggle():
    """Motor check: nudge each arm's first joint target by +0.2 and hold it.

    The robots are deactivated, so simulation_step only holds their targets
    instead of steering them toward trajectory points.
    """
    log("🔧 모터 테스트")
    first_joint_nudge = np.array([0.2, 0, 0, 0, 0, 0])
    for robot in robots:
        robot.active = False
        robot.target_joints += first_joint_nudge
        robot.entity.control_dofs_position(robot.target_joints, robot.dofs_idx)

def exit_simulation():
    """Stop the update loop, close the GUI window, and exit the process.

    Also installed as the window-close (WM_DELETE_WINDOW) handler.
    """
    global running
    running = False
    try:
        root.destroy()
    except tk.TclError:
        # The window was already destroyed; nothing left to close.
        pass
    sys.exit()

def save_trajectory():
    """Prompt for a file name and save the current trajectory points as JSON.

    The file holds a list of [x, y, z] lists. Cancelling the dialog does
    nothing; any failure is reported in the log instead of raised.
    """
    try:
        serialized = [point.tolist() for point in trajectory_points]
        filename = filedialog.asksaveasfilename(
            defaultextension=".json",
            filetypes=[("JSON files", "*.json"), ("All files", "*.*")],
            title="궤적 저장"
        )
        if filename:
            with open(filename, "w") as out_file:
                json.dump(serialized, out_file)
            log(f"💾 궤적 저장 완료: {os.path.basename(filename)}")
    except Exception as err:
        log(f"❌ 저장 실패: {err}")

def load_trajectory():
    """Prompt for a JSON file and replace the trajectory points with its contents.

    The file must hold exactly ``N_POINTS`` entries, each a numeric
    [x, y, z] list (the format save_trajectory writes). The data is
    validated before any global state changes, so a bad file leaves the
    current trajectory and markers untouched. Cancelling the dialog does
    nothing; errors are logged rather than raised.
    """
    global trajectory_points
    try:
        filename = filedialog.askopenfilename(
            filetypes=[("JSON files", "*.json"), ("All files", "*.*")],
            title="궤적 불러오기"
        )
        if not filename: return

        with open(filename, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Validate first: the rest of the program indexes
        # trajectory_points[0 .. N_POINTS-1] and treats each entry as a
        # 3-vector, so a wrong-sized file must not be committed.
        points = [np.asarray(p, dtype=float) for p in data]
        if len(points) != N_POINTS or any(p.shape != (3,) for p in points):
            log(f"❌ 로드 실패: {N_POINTS}개의 [x, y, z] 좌표가 필요합니다 (받은 개수: {len(points)})")
            return

        trajectory_points = points
        reset_markers_visibility()

        log(f"📂 궤적 로드 완료: {os.path.basename(filename)}")
    except Exception as e:
        log(f"❌ 로드 실패: {e}")

# NOTE(review): the label says "Random 20", but start_mission reuses the
# current points (random or loaded) rather than generating new ones.
ttk.Button(ctrl_frame, text="▶ 미션 시작 (Random 20)", command=start_mission).pack(fill='x', pady=2)

# Horizontal row of trajectory buttons. Regenerate passes False, so
# regenerate_markers does not log a confirmation line.
btn_frame = ttk.Frame(ctrl_frame)
btn_frame.pack(fill='x', pady=2)
ttk.Button(btn_frame, text="🔄 재생성", command=lambda: regenerate_markers(False)).pack(side='left', expand=True, fill='x', padx=1)
ttk.Button(btn_frame, text="💾 저장", command=save_trajectory).pack(side='left', expand=True, fill='x', padx=1)
ttk.Button(btn_frame, text="📂 불러오기", command=load_trajectory).pack(side='left', expand=True, fill='x', padx=1)

ttk.Button(ctrl_frame, text="🔧 모터 테스트", command=test_wiggle).pack(fill='x', pady=2)
ttk.Separator(root, orient='horizontal').pack(fill='x', pady=10)
ttk.Button(root, text="❌ 프로그램 종료", command=exit_simulation).pack(fill='x', padx=20, pady=5)

def update_loop():
    """Run one GUI-timer tick: step physics, then the controller, then reschedule.

    An exception raised during a tick is printed and swallowed so one bad
    frame cannot stop the loop. The next tick is scheduled 16 ms later for
    as long as ``running`` is True.
    """
    try:
        if running:
            scene.step()       # advance the physics first
            simulation_step()  # then run the controller on the new state
    except Exception as err:
        # Report and skip this tick; the loop keeps going.
        print(f"Update Error: {err}")
        # traceback.print_exc()  # uncomment for a full stack trace while debugging

    # Reschedule outside the try so an error can never end the loop.
    if not running:
        return
    root.after(16, update_loop)

# Startup: enable the update loop, route the window-close button through
# exit_simulation, schedule the first tick, and enter the Tk event loop.
running = True
root.protocol("WM_DELETE_WINDOW", exit_simulation)
log("시스템 준비 완료.")
root.after(16, update_loop)
root.mainloop()