# from datacanvas.aps import dc
import os
import random
import torch
import sys
import json
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM, AutoConfig
from pathlib import Path
from torch.utils.data import Dataset, random_split
import re
import copy
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
# from datacanvas import doc
from peft import (
    LoraConfig,
    get_peft_model,
    PeftModel,
)
import logging


# Module-level logger used by main() for config/model status messages.
logger = logging.getLogger(__name__)
# Label value written to token positions that must not contribute to the loss
# (the same value as the default `ignore_index` of torch.nn.CrossEntropyLoss).
IGNORE_INDEX = -100
# Set at import time; presumably turns off Weights & Biases reporting in
# transformers -- NOTE(review): confirm against the installed version.
os.environ["WANDB_DISABLED"] = "true"
# Config classes that have a causal-LM model, and their `model_type` strings;
# MODEL_TYPES is only used to build the --model_type help text.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def tokenize(text, tokenizer, eos_token_id=2):
    """Tokenize one training sample and build labels limited to assistant turns.

    Every span of ``text`` matching ``"\\n Assistant:" ... "</s>"`` is mapped
    back to token positions through the tokenizer's offset mapping. Tokens
    outside those spans get ``IGNORE_INDEX`` as their label. When no span
    matches, the labels stay a full copy of the input ids. The final label is
    always overwritten with ``eos_token_id``.

    Args:
        text: Raw sample string.
        tokenizer: Callable accepting ``(text, return_offsets_mapping=True)``
            and returning a mapping with ``input_ids``, ``attention_mask`` and
            ``offset_mapping`` (a sequence of ``(start, end)`` character spans).
        eos_token_id: Token id written to the last label position. Defaults
            to 2, the value that used to be hard-coded here.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``.
    """
    inputs_with_offsets = tokenizer(text, return_offsets_mapping=True)
    labels = copy.deepcopy(inputs_with_offsets['input_ids'])
    offsets = inputs_with_offsets["offset_mapping"]

    # Token positions whose labels are kept; a set gives O(1) membership tests.
    keep = set()
    for match in re.finditer(r'(\n Assistant:)(.*?)</s>', text, re.DOTALL):
        start_pos, end_pos = match.span()
        start_idx = None
        end_idx = None

        for i, (start, end) in enumerate(offsets):
            changed = False
            # The upper bound is inclusive, so a character position sitting on
            # a token boundary matches both neighbours; the later token wins.
            if start <= start_pos < end + 1:
                start_idx = i
                changed = True
            if start <= end_pos < end + 1:
                end_idx = i
                changed = True

            # Record the range only when an index moved. This yields the same
            # set as re-adding the range on every iteration, without the
            # repeated work.
            if changed and start_idx is not None and end_idx is not None:
                keep.update(range(start_idx, end_idx))

    if keep:
        for k in range(len(labels)):
            if k not in keep:
                labels[k] = IGNORE_INDEX

    # An empty encoding has no last position to overwrite.
    if labels:
        labels[-1] = eos_token_id

    return dict(
        input_ids=inputs_with_offsets['input_ids'],
        attention_mask=inputs_with_offsets['attention_mask'],
        labels=labels,
    )

class data_sets(Dataset):
    """In-memory dataset of pre-tokenized samples.

    Each text in ``txt_list`` is passed through ``tokenize`` at construction
    time, and the resulting ids, attention mask and labels are stored in three
    parallel lists. The index of every processed sample is printed as it is
    encoded.
    """

    def __init__(self, txt_list, tokenizer):
        self.input_ids = []
        self.attn_masks = []
        self.labels = []
        # Pair each storage list with the key it is filled from.
        columns = (
            (self.input_ids, 'input_ids'),
            (self.attn_masks, 'attention_mask'),
            (self.labels, 'labels'),
        )
        for position, sample in enumerate(txt_list):
            encoded = tokenize(sample, tokenizer)
            for store, key in columns:
                store.append(encoded[key])
            print(position)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, labels)`` for sample ``idx``."""
        ids = self.input_ids[idx]
        mask = self.attn_masks[idx]
        target = self.labels[idx]
        return ids, mask, target
    
    
def the_collate_fn(batch, pad_token_id=3, label_pad_id=-100):
    """Left-pad a batch of ``(input_ids, attention_mask, labels)`` samples.

    Every sequence is padded on the left to the length of the longest
    ``input_ids`` in the batch.

    Args:
        batch: Non-empty sequence of ``(input_ids, attention_mask, labels)``
            triples, each element a sequence of ints.
        pad_token_id: Fill value for ``input_ids``. Defaults to 3, the value
            that used to be hard-coded here.
        label_pad_id: Fill value for ``labels``. Defaults to -100.

    Returns:
        dict of three ``torch.long`` tensors of shape ``(len(batch), maxlength)``
        keyed ``input_ids``, ``attention_mask`` and ``labels``. The attention
        mask is padded with 0.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("the_collate_fn received an empty batch")

    maxlength = max(len(sample[0]) for sample in batch)

    def _left_pad(seq, fill):
        # Padding goes on the left so real tokens stay right-aligned.
        return [fill] * (maxlength - len(seq)) + list(seq)

    input_ids = torch.tensor(
        [_left_pad(sample[0], pad_token_id) for sample in batch], dtype=torch.long
    )
    attention_mask = torch.tensor(
        [_left_pad(sample[1], 0) for sample in batch], dtype=torch.long
    )
    labels = torch.tensor(
        [_left_pad(sample[2], label_pad_id) for sample in batch], dtype=torch.long
    )
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}


class Mytrainer(Trainer):
    """Trainer whose loss comes straight from the model's own forward pass."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model on a collated batch and return its loss.

        Args:
            model: The model being trained.
            inputs: Mapping with ``input_ids``, ``attention_mask`` and
                ``labels`` tensors.
            return_outputs: When true, return ``(loss, logits)`` instead of
                just the loss.
            num_items_in_batch: Accepted so the override stays callable by
                Trainer versions that pass this keyword; it is not forwarded
                to the model.

        Returns:
            The loss, or ``(loss, logits)`` when ``return_outputs`` is true.
        """
        outputs = model(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=inputs["labels"],
        )
        # With labels supplied, the first two output entries are loss and logits.
        loss, logits = outputs[:2]
        return (loss, logits) if return_outputs else loss
        

@dataclass
class ModelArguments:
    """
    Command-line arguments selecting the model, config and tokenizer, plus the dataset and LoRA settings.
    """

    # --- model / config / tokenizer selection ---
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES),
        },
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": "Override some existing default config settings when a model is trained from scratch. "
            "Example: n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )

    # --- dataset settings ---
    # JSON/JSONL file, or a directory searched recursively for such files.
    input_dataset_path: Optional[str] = field(default="data.json")
    # Fraction of the samples placed in the training split.
    train_size: Optional[float] = field(default=0.8)
    # Samples that tokenize to more than this many tokens are dropped.
    max_seq_length: Optional[int] = field(default=256)

    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": "Override the default `torch.dtype` and load the model under this dtype. "
            "If `auto` is passed, the dtype will be automatically derived from the model's weights.",
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )

    # --- LoRA hyper-parameters (passed to peft.LoraConfig in main) ---
    LORA_R: Optional[int] = field(default=2)
    LORA_ALPHA: Optional[int] = field(default=4)
    LORA_DROPOUT: Optional[float] = field(default=0.05)
    # String form of a Python list of module names, parsed in main.
    TARGET_MODULES: Optional[str] = field(default="['q_proj']")

    def __post_init__(self):
        # Overrides only make sense when the config is built from scratch.
        if self.config_overrides is None:
            return
        explicit_sources = (self.config_name, self.model_name_or_path)
        if any(source is not None for source in explicit_sources):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


def _load_instructions(dataset_path):
    """Collect training samples from a JSON/JSONL file or a directory tree.

    A directory is searched recursively: every ``*.jsonl`` file contributes
    one sample per non-blank line, then every ``*.json`` file contributes the
    items of its top-level JSON value. A single file is read according to its
    ``.json`` / ``.jsonl`` suffix. Any other path yields an empty list.
    """
    samples = []

    def read_jsonl(path):
        with open(path, encoding="utf-8") as handle:
            for raw in handle:
                # Blank lines would otherwise raise a JSON decode error.
                if raw.strip():
                    samples.append(json.loads(raw))

    def read_json(path):
        with open(path, "r", encoding="utf-8") as handle:
            samples.extend(json.load(handle))

    if os.path.isdir(dataset_path):
        root = Path(dataset_path)
        for jsonl_file in root.glob("**/*.jsonl"):
            read_jsonl(jsonl_file)
        for json_file in root.glob("**/*.json"):
            read_json(json_file)
    elif os.path.isfile(dataset_path):
        if dataset_path.endswith(".jsonl"):
            read_jsonl(dataset_path)
        elif dataset_path.endswith(".json"):
            read_json(dataset_path)
    return samples


def main():
    """Parse arguments, wrap a causal LM with LoRA adapters, and fine-tune it.

    Steps: parse ``ModelArguments`` and ``TrainingArguments`` (from a single
    JSON file argument or from the command line), load config, tokenizer and
    model, attach LoRA adapters, load and length-filter the dataset, split it
    into train/validation parts, train with ``Mytrainer`` and save the model
    to ``training_args.output_dir``.

    Raises:
        ValueError: If no tokenizer source is given, or if no sample survives
            loading and length filtering.
    """
    import ast

    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        args, training_args = parser.parse_args_into_dataclasses()

    input_dataset_path = args.input_dataset_path
    max_seq_length = args.max_seq_length
    output_dir = str(training_args.output_dir)

    # SECURITY: this value comes from the command line / JSON config and used
    # to go through eval(). literal_eval accepts only Python literals, e.g.
    # "['q_proj', 'v_proj']" or "['query_key_value']".
    target_modules = ast.literal_eval(args.TARGET_MODULES)

    config_kwargs = {
        "trust_remote_code": True,
        "use_cache": False,
    }
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if args.config_overrides is not None:
            logger.info(f"Overriding config: {args.config_overrides}")
            config.update_from_string(args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "trust_remote_code": True,
    }
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, **tokenizer_kwargs)
    elif args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if args.model_name_or_path:
        # "auto" and None are passed through unchanged; any other value is
        # looked up as an attribute of torch (e.g. torch.bfloat16).
        torch_dtype = (
            args.torch_dtype
            if args.torch_dtype in ["auto", None]
            else getattr(torch, args.torch_dtype)
        )
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
        )
    else:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        # Keying by data_ptr counts shared (tied) parameters only once.
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model - Total size={n_params/2**20:.2f}M params")

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))
    print(model)

    # Kept under its own name so the model config above is not shadowed.
    lora_config = LoraConfig(
        r=args.LORA_R,
        lora_alpha=args.LORA_ALPHA,
        target_modules=target_modules,
        lora_dropout=args.LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    instructions = _load_instructions(input_dataset_path)
    random.shuffle(instructions)

    # Drop samples that tokenize to more than max_seq_length tokens.
    instruction_list = [
        sample for sample in instructions
        if len(tokenizer.encode(sample)) <= max_seq_length
    ]
    if not instruction_list:
        raise ValueError(
            f"No training samples within max_seq_length={max_seq_length} "
            f"were found at {input_dataset_path!r}."
        )

    # NOTE(review): every sample is replicated 8000 times before the split, so
    # the train and validation parts contain copies of the same texts. Confirm
    # this is intended (it looks like oversampling of a very small dataset).
    dataset = data_sets(instruction_list * 8000, tokenizer)
    n_train = int(args.train_size * len(dataset))
    train_dataset, val_dataset = random_split(dataset, [n_train, len(dataset) - n_train])
    print("train_dataset:", len(train_dataset))
    print("val_dataset:", len(val_dataset))

    from aim.hugging_face import AimCallback
    aim_callback = AimCallback()
    trainer = Mytrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=the_collate_fn,
        callbacks=[aim_callback],
    )
    trainer.train()
    trainer.save_model(output_dir)

    print("Llama SFT Done")
    
# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
