import json
import logging
import os
from typing import *
import time
import sys
import argparse
import torch
from tqdm import tqdm
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from basic_utils import read_json,get_dir_filepaths, add_jsonl,write_json
import requests,re
import jsonlines
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader
from PIL import Image
from PIL.Image import Image as ImageObject
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import csv

# Drop any handlers already attached to the root logger; otherwise basicConfig()
# below is a no-op and nothing would be written to the log file.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

log_path = "/workspace/grpo_implement/logs/eval_countdown.log"
# The file handler cannot create missing directories, so make sure the log dir exists.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
# Explicit UTF-8: log records contain non-ASCII text (json.dumps(..., ensure_ascii=False)
# and the 【 】 markers in the final summary), which a non-UTF-8 locale encoding may reject.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)-8s | %(filename)s:%(lineno)d | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[logging.FileHandler(log_path, mode='a', encoding='utf-8')],
)
logger = logging.getLogger(__file__)


# =======================================================

results_csv_summary_path = "/workspace/grpo_implement/evaluation/results/eval_countdown_summary.csv"

# -------- Initialize record csv --------
# Create the summary file with its header row on first use; each run appends one row later.
if not os.path.exists(results_csv_summary_path):
    # open(..., "w") does not create missing parent directories.
    os.makedirs(os.path.dirname(results_csv_summary_path), exist_ok=True)
    # newline="" lets the csv writer control line endings (no blank rows on Windows).
    with open(results_csv_summary_path, 'w', newline='') as res_file:
        f_writer = csv.writer(res_file, delimiter=',')
        f_writer.writerow(["Eval Prefix", "Correct Count", "Accuracy", "Accuracy (%)", "Model Path", "Dataset Path", "Gen kwargs", "Eval Num"])

# =======================================================


# Keyword arguments for vLLM SamplingParams: temperature 0.0 selects greedy decoding,
# and each generation is capped at 2048 new tokens.
gen_kwargs = dict(
    max_tokens=2048,
    temperature=0.0,
)


parser = argparse.ArgumentParser(description="Evaluate a model on the Countdown task with vLLM.")
# model_path, dataset_path and eval_num have no usable fallback (loading the tokenizer,
# loading the dataset and selecting samples all fail on None), so argparse enforces them.
parser.add_argument("--model_path", type=str, required=True,
                    help="Model directory or hub id used for both the tokenizer and vLLM.")
parser.add_argument("--dataset_path", type=str, required=True,
                    help="JSON/JSONL file with the evaluation samples.")
parser.add_argument("--save_prefix", type=str,
                    help="Tag used in the output file name and in the CSV summary row.")
parser.add_argument("--save_root", type=str, default="/workspace/grpo_implement/evaluation/results",
                    help="Directory that receives the per-sample results JSON.")
parser.add_argument("--eval_num", type=int, required=True,
                    help="Number of leading dataset samples to evaluate (must be > 0).")

cli_args = parser.parse_args()
# A non-positive count leaves nothing to evaluate (range(eval_num) would be empty).
if cli_args.eval_num <= 0:
    parser.error("--eval_num must be a positive integer")

model_path = cli_args.model_path
dataset_path = cli_args.dataset_path
eval_num = cli_args.eval_num
save_prefix = cli_args.save_prefix
save_root = cli_args.save_root

logger.info("\n\n" + "#"*100 + f"\n[model_path]={model_path}\n[dataset_path]={dataset_path}\n[eval_num]={eval_num}\n[save_root]={save_root}\n[save_prefix]={save_prefix}\n[gen_kwargs]={gen_kwargs}\n")

save_path = os.path.join(save_root, f"eval_{eval_num}_{save_prefix}.json")
logger.info(f"Will save results to: {save_path}")

# =======================================================
start_time = time.time()

eval_dataset = load_dataset("json", data_files=dataset_path, split="train")
logger.info(f"Loaded eval dataset with {len(eval_dataset)} samples.")
# Dataset.select() fails on out-of-range indices, so cap the request at the dataset size
# (a missing eval_num is treated as "evaluate everything").
if eval_num is None or eval_num > len(eval_dataset):
    logger.warning(f"Requested eval_num={eval_num}, but the dataset has {len(eval_dataset)} samples; evaluating all of them.")
    eval_num = len(eval_dataset)
eval_dataset = eval_dataset.select(range(eval_num))
# Keep eval_num equal to the number of samples actually selected for evaluation.
eval_num = len(eval_dataset)
logger.info(f"Selected {len(eval_dataset)} samples.")

tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Tokenizer loaded.")

# Keyword arguments for vllm.LLM; the engine itself is constructed where generation runs.
engine_args = {
    "model": model_path,
    "trust_remote_code": True,
    "dtype": "bfloat16",
    # Shard the model across every visible GPU.
    "tensor_parallel_size": torch.cuda.device_count(),
    "disable_log_stats": True,
    # NOTE(review): enable_lora is not set here, so this LoRA limit presumably has no effect — confirm.
    "max_lora_rank": 128,
}

sampling_params = SamplingParams(**gen_kwargs)
# No model has been constructed at this point, only its configuration.
logger.info("Engine args and sampling params prepared.")


def collate_fn(batch):
    """Tokenize a list of dataset rows and carry every original column along.

    Each ``item["prompt"]`` is tokenized on its own, WITHOUT padding: the resulting
    ids are passed to vLLM as ``prompt_token_ids`` with no attention mask, so any pad
    tokens would silently become part of the prompt. Skipping padding also avoids the
    error raised by tokenizers that define no pad token.

    Args:
        batch: list of dataset rows (dicts), each with at least a ``"prompt"`` key.

    Returns:
        The tokenizer's output (``input_ids`` etc. as ragged lists of ints), extended
        with one list per dataset column so the scoring loop can index samples by
        position.
    """
    prompts = [item["prompt"] for item in batch]
    # No return_tensors / padding: keep each prompt's exact token sequence.
    tokenized = tokenizer(prompts)

    # Copy every raw column through (e.g. prompt, target, nums, test_id).
    for key in batch[0]:
        tokenized[key] = [item[key] for item in batch]

    return tokenized


# The first <answer>...</answer> span of a generation is taken as the model's answer.
_ANSWER_RE = re.compile(r"<answer>(.*?)<\/answer>")
# Model output is untrusted: only plain arithmetic (digits, + - * /, parentheses,
# decimal points, whitespace) is ever handed to eval().
_ARITHMETIC_RE = re.compile(r"[\d+\-*/().\s]+")


def _evaluate_arithmetic(expression):
    """Evaluate an arithmetic expression extracted from a generation.

    Raises ValueError if *expression* is not plain arithmetic, and propagates whatever
    eval() raises (SyntaxError, ZeroDivisionError, ...) for malformed arithmetic.
    """
    if not _ARITHMETIC_RE.fullmatch(expression):
        raise ValueError(f"answer is not a plain arithmetic expression: {expression!r}")
    # Empty builtins mapping so no builtin functions are reachable from the expression.
    return eval(expression, {"__builtins__": {}}, {})


# The whole (already truncated) dataset is processed as a single batch; max(..., 1)
# keeps DataLoader valid (batch_size must be positive) when the dataset is empty.
dataloader = DataLoader(eval_dataset, batch_size=max(len(eval_dataset), 1), collate_fn=collate_fn)

# Build the vLLM engine exactly once instead of once per batch.
llm = LLM(**engine_args)
logger.info("vLLM engine initialized.")

# Test sample keys >>>>>> ["id", "prompt", "target", "nums", "test_id"]
save_list = []
total_score = 0
format_error_num = 0

for sample_batch in dataloader:
    # vLLM takes plain lists of ints as prompt_token_ids; int() also unwraps tensor elements.
    inputs = [{"prompt_token_ids": [int(tok) for tok in ids]} for ids in sample_batch["input_ids"]]
    results = llm.generate(inputs, sampling_params)

    preds = [result.outputs[0].text for result in results]

    for batch_idx, gen_str in enumerate(preds):
        test_sample = {kk: vv[batch_idx] for kk, vv in sample_batch.items()}
        score = 0
        answer_result = None
        answer_match = _ANSWER_RE.search(gen_str)

        if answer_match is None:
            # No <answer> tags at all: counted as a format error, scored 0.
            format_error_num += 1
        else:
            answer_content = answer_match.group(1).strip()
            try:
                answer_result = _evaluate_arithmetic(answer_content)
                # Tolerance absorbs float rounding from '/' in otherwise exact answers.
                if abs(float(answer_result) - float(test_sample['target'])) < 1e-5:
                    score = 1
                    total_score += 1
            except Exception as e:
                # Unparseable / non-numeric answers score 0 and record no calculation.
                answer_result = None
                logger.warning(f"[Exception] at test id = {test_sample['test_id']}, message: {e}")

        eval_dict = {
            "test_id": test_sample['test_id'],
            "score": score,
            "prompt": test_sample['prompt'],
            "nums": test_sample['nums'],
            "generation": gen_str,
            "gen_calculation": answer_result,
            "target": test_sample['target'],
        }

        logger.info(f"[Test ID = {test_sample['test_id']}]\t| {json.dumps(eval_dict, ensure_ascii=False)}")
        save_list.append(eval_dict)
        
        
# Denominator = samples actually scored; an empty run reports 0% instead of raising
# ZeroDivisionError.
num_evaluated = len(save_list)
accuracy = total_score / num_evaluated if num_evaluated else 0.0
format_error_rate = format_error_num / num_evaluated if num_evaluated else 0.0

logger.info(f"【 model_path 】= {model_path}")
logger.info(f"【 dataset_path 】= {dataset_path}")
logger.info("【 Total score 】 =\t{:d}/{:d} ({:.2f}%)".format(total_score, num_evaluated, accuracy * 100))
logger.info("【 Format error counts 】 =\t{:d}/{:d} ({:.2f}%)".format(format_error_num, num_evaluated, format_error_rate * 100))
write_json(save_list, save_path)

# Append one row per run; column order matches the header written when the file was created.
# newline="" lets the csv writer control line endings.
with open(results_csv_summary_path, 'a', newline='') as res_file:
    f_writer = csv.writer(res_file, delimiter=',')
    f_writer.writerow([save_prefix, total_score, round(accuracy, 3), "{:.2f}%".format(accuracy * 100), model_path, dataset_path, gen_kwargs, num_evaluated])
logger.info(f"Appended result summary to csv file: {results_csv_summary_path}")

end_time = time.time()
logger.info("Total time consumption = {:.1f}".format(end_time - start_time))