import datetime
from datasets import load_dataset
import json
import logging
import os
from typing import *
import time
from dataclasses import dataclass
import sys

sys.path.append("/workspace/grpo_implement")
from reward_functions.reward_func_def import format_reward_func, countdown_acc_reward_func
from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig, TrlParser
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from transformers.trainer_utils import get_last_checkpoint


# =======================================================
# Startup banner: identify this process and the GPUs it was restricted to.
print(f"【pid】 = {os.getpid()}")

# CUDA_VISIBLE_DEVICES is only present when the launcher restricts the GPUs.
# Indexing os.environ directly would abort the whole script with KeyError when
# it is unset, so fall back to a placeholder instead.
print(f"Running on device = {os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')}")
# =======================================================

# Route all INFO+ records to a log file. basicConfig() silently does nothing if
# the root logger already has handlers (e.g. installed by an imported library),
# so force=True removes -- and closes -- any existing root handlers first.
log_path = "/workspace/grpo_implement/logs/countdown_zero_rl_aha_moment_dist.log"
# basicConfig opens the file immediately; create its directory if it is missing
# instead of crashing with FileNotFoundError at import time.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)-8s | %(filename)s:%(lineno)d | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    filename=log_path,
    filemode='w',
    force=True,
)
# NOTE(review): under a multi-process launcher every process executes this module
# and opens the same file in 'w' mode; consider a per-rank file name.
logger = logging.getLogger(__name__)


# =======================================================
@dataclass
class ScriptArguments:
    """Script-specific command-line arguments, parsed by ``TrlParser`` in ``main``.

    Attributes:
        dataset_id_or_path: JSON data file(s) handed to ``load_dataset`` as ``data_files``.
        dataset_splits: Dataset split name; defaults to ``"train"``.
        tokenizer_name_or_path: Optional tokenizer id or path; ``None`` when not given.
    """

    dataset_id_or_path: str
    dataset_splits: str = "train"
    # Optional[...] so the annotation agrees with the None default and the
    # argument parser treats the flag as optional.
    tokenizer_name_or_path: Optional[str] = None

# =======================================================
def _load_train_test_splits(script_args, seed: int = 42, test_size: float = 0.1):
    """Load the JSON dataset, shuffle it and hold out ``test_size`` of it for evaluation.

    The shuffle and the split both use ``seed``, so every process of a
    distributed run, and every resumed run, gets the same partition.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    dataset = load_dataset("json", data_files=script_args.dataset_id_or_path, split=script_args.dataset_splits)
    dataset = dataset.shuffle(seed=seed)
    # train_test_split permutes the rows again; without an explicit seed it draws a
    # fresh random permutation on every call (and therefore on every rank).
    splits = dataset.train_test_split(test_size=test_size, seed=seed)
    return splits["train"], splits["test"]


def _resolve_resume_checkpoint(training_args: GRPOConfig):
    """Return the checkpoint to resume from, or ``None`` to train from scratch.

    An explicit ``training_args.resume_from_checkpoint`` takes precedence.
    Otherwise the newest checkpoint found in ``training_args.output_dir`` (if
    that directory exists) is used.
    """
    if training_args.resume_from_checkpoint is not None:
        logger.info(f"Resuming training from requested checkpoint {training_args.resume_from_checkpoint}.")
        return training_args.resume_from_checkpoint

    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)

    if last_checkpoint is not None:
        logger.info(f"Checkpoint DETECTED, resuming training at {last_checkpoint}.")
    else:
        logger.info("No checkpoint detected, train from scratch.")
    return last_checkpoint


def run_grpo(model_args, script_args, training_args):
    """Train a causal LM with GRPO on a JSON dataset and save the results.

    Steps:
    - Load ``script_args.dataset_id_or_path`` and split off 10% as an eval set.
    - Train with ``format_reward_func`` and ``countdown_acc_reward_func`` as
      rewards. Training resumes from an explicit ``resume_from_checkpoint``,
      else from the latest checkpoint in ``output_dir``.
    - Write metrics, trainer state, the model, the tokenizer and a model card
      to ``training_args.output_dir``.
    - Optionally push the results to the Hub.

    Args:
        model_args: TRL ``ModelConfig`` (model id, optional PEFT settings).
        script_args: ``ScriptArguments`` (dataset location and split).
        training_args: TRL ``GRPOConfig``.
    """
    train_dataset, test_dataset = _load_train_test_splits(script_args)
    logger.info(
        f"Dataset splited. Number of training samples = {len(train_dataset)}, number of test samples = {len(test_dataset)}")

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    trainer = GRPOTrainer(
        model=model_args.model_name_or_path,
        reward_funcs=[format_reward_func, countdown_acc_reward_func],
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        peft_config=get_peft_config(model_args),
    )

    # Previously an explicit --resume_from_checkpoint was ignored; it now wins.
    resume_checkpoint = _resolve_resume_checkpoint(training_args)

    logger.info('*** Starting training ***')
    train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    logger.info("*** Training complete ***")

    logger.info("*** Save model ***")
    # NOTE(review): presumably the KV cache is disabled during training; turn it
    # back on in the saved config so generation with the exported model uses it.
    trainer.model.config.use_cache = True
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")
    # Barrier: every rank finishes save_model before the main process writes the
    # remaining files into the same directory.
    training_args.distributed_state.wait_for_everyone()

    if trainer.accelerator.is_main_process:
        # A single writer avoids several ranks clobbering the same tokenizer files.
        tokenizer.save_pretrained(training_args.output_dir)
        logger.info(f"Tokenizer saved to {training_args.output_dir}")
        # The first positional parameter of create_model_card is not `tags`;
        # passing a dict positionally put it in the wrong slot.
        trainer.create_model_card(tags=["rl", "grpo", "aha_moment"])

    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub()

    logger.info("*** End of Training.***")


def main():
    """Script entry point: parse all argument groups, log them, then run GRPO training.

    ``TrlParser.parse_args_and_config`` fills ``ModelConfig``, ``ScriptArguments``
    and ``GRPOConfig`` from the command line (and an optional config file).
    """
    cli_parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    parsed = cli_parser.parse_args_and_config()
    model_args, script_args, training_args = parsed

    logger.info(f"Model parameters {model_args}\n\n")
    logger.info(f"Training/evaluation parameters {training_args}")

    run_grpo(model_args, script_args, training_args)


if __name__ == "__main__":
    # Only launch training when executed directly, not when imported.
    main()
