import json
import io
import base64
import numpy as np
import cv2
import torch
import torchxrayvision as xrv
from torchvision import models
import os
from openai import OpenAI
from medical_agent import VascularConsultantAgent
from settings import OPENAI_API_KEY, OPENAI_MODEL

class MedicalBrainEngine:
    """Diagnostic pipeline: local X-ray inference -> LLM summary -> consultant-agent audit."""

    def __init__(self, xrv_weights="densenet121-res224-all"):
        """Create the OpenAI client, the consultant agent and the local XRV classifier.

        Args:
            xrv_weights: TorchXRayVision DenseNet weight set to load. The preprocessing
                in ``_run_xrv`` resizes to 224x224, so a ``*-res224-*`` weight set is expected.
        """
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.agent = VascularConsultantAgent(api_key=OPENAI_API_KEY)

        # process_xray_with_torch reads self.xrv_model; it must be an XRV model exposing
        # `.pathologies` (a generic ImageNet classifier has no such attribute).
        self.xrv_model = xrv.models.DenseNet(weights=xrv_weights)
        self.xrv_model.eval()
        self.xrv_model.cpu()
        # Backward-compatible alias for code that still reads `engine.model`.
        self.model = self.xrv_model

    @staticmethod
    def _rank_findings(pathologies, scores, top_k=5):
        """Pair labels with scores, drop unnamed outputs, and return the top_k by score.

        Each finding is ``{"name": label, "score": score rounded to 4 places}``,
        sorted by descending score.
        """
        findings = [
            {"name": name, "score": round(float(score), 4)}
            for name, score in zip(pathologies, scores)
            # XRV weight sets trained on a label subset can report "" for untrained outputs.
            if name
        ]
        findings.sort(key=lambda finding: finding["score"], reverse=True)
        return findings[:top_k]

    def _run_xrv(self, base64_str, top_k=5):
        """Decode a base64 (optionally data-URL) image and return its top_k XRV findings.

        Raises:
            Exception: on malformed base64, undecodable image data, or inference failure.
        """
        # Accept both "data:image/...;base64,<payload>" and a bare base64 payload.
        encoded = base64_str.split(",", 1)[-1]
        img_bytes = base64.b64decode(encoded)
        arr = np.frombuffer(img_bytes, np.uint8)
        img = cv2.imdecode(arr, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imdecode reports undecodable data by returning None, not by raising.
            raise ValueError("could not decode image data")

        img = xrv.datasets.normalize(img, 255)
        img = img[None, :, :]  # add channel axis: (1, H, W)
        img = xrv.datasets.XRayCenterCrop()(img)
        img = xrv.datasets.XRayResizer(224)(img)
        batch = torch.from_numpy(img).unsqueeze(0)  # add batch axis: (1, 1, 224, 224)

        with torch.no_grad():
            scores = self.xrv_model(batch)[0].cpu().numpy()
        return self._rank_findings(self.xrv_model.pathologies, scores, top_k)

    def process_xray_with_torch(self, base64_str, top_k=5):
        """Return the top_k XRV findings for a base64 image as a JSON string.

        On any failure, returns a string starting with "XRV Analysis Error: " instead of raising.
        """
        try:
            return json.dumps(self._run_xrv(base64_str, top_k))
        except Exception as e:
            return f"XRV Analysis Error: {e}"

    def stage_1_vision_analysis(self, base64_image):
        """Run local XRV inference, then ask the LLM to summarize the findings.

        Returns the LLM's message content, or an "XRV Analysis Error: ..." /
        "Vision Logic Error: ..." string if the respective step fails.
        """
        print("[Vision Analysis] Running local XRV inference...")

        try:
            findings = self._run_xrv(base64_image)
        except Exception as e:
            # Do not ask the LLM to "interpret" an error message as if it were findings.
            return f"XRV Analysis Error: {e}"
        xrv_findings = json.dumps(findings)

        try:
            # NOTE(review): XRV outputs chest-pathology labels; confirm the vessel-focused
            # prompt below is what is intended for these findings.
            response = self.client.chat.completions.create(
                model=OPENAI_MODEL,
                messages=[{
                    "role": "user",
                    "content": f"Analyze these medical findings extracted from an X-ray: {xrv_findings}. Summarize clinical significance regarding vessel patency and stenosis."
                }]
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"Vision Logic Error: {str(e)}"

    def run_full_diagnostic_direct(self, full_data):
        """Run vision analysis on the most recent image and hand the result to the agent.

        Args:
            full_data: mapping whose optional "data" entry holds patient info; images are
                read from ``data["radiology_images"]`` as base64 strings (last one is used).

        Returns:
            Whatever ``self.agent.generate_final_audit(visual_findings, patient_info)`` returns.
        """
        patient_info = full_data.get('data')
        if patient_info is None:
            # Tolerate an explicit `"data": None` as well as a missing key.
            patient_info = {}
        images = patient_info.get('radiology_images')
        if images is None:
            images = []

        if images:
            visual_findings = self.stage_1_vision_analysis(images[-1])
        else:
            visual_findings = "No images provided."

        return self.agent.generate_final_audit(visual_findings, patient_info)