import torch
import time
import yaml
import json
import jsonlines
import traceback
import os
from fnmatch import fnmatch


def convert_torch_dtype(str_dtype):
    """Map a short precision tag to the matching ``torch`` dtype.

    Args:
        str_dtype: One of ``"bf16"``, ``"fp16"`` or ``"fp32"``.

    Returns:
        ``torch.bfloat16``, ``torch.float16`` or ``torch.float32`` respectively.

    Raises:
        Exception: If ``str_dtype`` is not one of the supported tags.
    """
    supported = (
        ("bf16", torch.bfloat16),
        ("fp16", torch.float16),
        ("fp32", torch.float32),
    )
    # Tags are compared with ``==`` in declaration order, so any unknown
    # value simply falls through to the error below.
    for tag, dtype in supported:
        if str_dtype == tag:
            return dtype
    raise Exception(f"Current function 'convert_torch_dtype()' does not support type string: {str_dtype}.")


def get_now_datetime(format_pattern='%Y-%m-%d %H:%M:%S', utc_offset_hours=8):
    """Return the current time at a fixed UTC offset as a formatted string.

    The offset is applied to UTC rather than to the host's local time, so the
    result does not depend on the timezone configured on the machine (or
    container) running the code.

    Args:
        format_pattern: ``time.strftime`` format string.
        utc_offset_hours: Hours to add to UTC. The default of 8 yields
            Beijing clock time (UTC+8).

    Returns:
        The formatted timestamp string.
    """
    shifted_epoch = time.time() + utc_offset_hours * 60 * 60
    # gmtime (not localtime) keeps the result independent of the host TZ.
    return time.strftime(format_pattern, time.gmtime(shifted_epoch))

def convert_datetime(the_time, format_pattern='%Y-%m-%d %H:%M:%S'):
    """Format an epoch timestamp as a string in the host's local timezone.

    Args:
        the_time: Seconds since the epoch, as accepted by ``time.localtime``.
        format_pattern: ``time.strftime`` format string.

    Returns:
        The formatted timestamp string.
    """
    local_struct = time.localtime(the_time)
    formatted = time.strftime(format_pattern, local_struct)
    return formatted

def read_yaml(yaml_path):
    """Load a YAML file and return its parsed content.

    Args:
        yaml_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's default locale encoding.
    with open(yaml_path, encoding="utf-8") as stream:
        return yaml.safe_load(stream)


def read_json(json_pth):
    """Load a JSON file and return the decoded object.

    Args:
        json_pth: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load``.
    """
    # Explicit UTF-8 keeps non-ASCII content readable on hosts whose
    # default locale encoding is something else.
    with open(json_pth, "r", encoding="utf-8") as fp:
        return json.load(fp)

def write_json(dict_list, json_pth, ensure_ascii=False):
    """Serialize ``dict_list`` to ``json_pth`` as 4-space-indented JSON.

    Args:
        dict_list: JSON-serializable object to write.
        json_pth: Destination path; an existing file is overwritten.
        ensure_ascii: Forwarded to ``json.dumps``. With the default ``False``
            non-ASCII characters are written as-is.

    Raises:
        TypeError: If ``dict_list`` contains a value ``json.dumps`` cannot
            serialize (raised before the destination file is touched).
    """
    # Serialize before opening so a failure here cannot truncate an
    # existing file.
    json_obj = json.dumps(dict_list, ensure_ascii=ensure_ascii, indent=4)
    # Explicit UTF-8: with ensure_ascii=False the text may contain characters
    # that the platform's default encoding cannot represent.
    with open(json_pth, 'w', encoding='utf-8') as json_ff:
        json_ff.write(json_obj)
    print(f"*** Written JSON file to: {json_pth}")
        
def read_jsonl(jsonl_pth, len_limit=None):
    """Read records from a JSON Lines file into a list.

    Args:
        jsonl_pth: Path of the ``.jsonl`` file to read.
        len_limit: Maximum number of records to return. ``None`` or ``0``
            means no limit.

    Returns:
        A list with the decoded records, in file order.
    """
    assert len_limit is None or isinstance(len_limit, int)
    outlist = []
    # Plain read mode: "r+" would demand write permission and fail on
    # read-only files even though nothing is ever written.
    with open(jsonl_pth, "r", encoding="utf8") as jl_ff:
        for idx, item in enumerate(jsonlines.Reader(jl_ff)):
            if len_limit and idx >= len_limit:
                break
            outlist.append(item)
    return outlist

def write_jsonl(list_of_dicts, jsonl_pth):
    """Write every item of ``list_of_dicts`` to a JSON Lines file.

    Args:
        list_of_dicts: Iterable of records to serialize, one per line.
        jsonl_pth: Destination path, opened in write (``'w'``) mode.
    """
    with jsonlines.open(jsonl_pth, mode='w') as sink:
        sink.write_all(list_of_dicts)

        
def add_jsonl(list_of_dicts, jsonl_pth, logger=None):
    """Append one or more records to a JSON Lines file.

    Args:
        list_of_dicts: A list of records, or a single record. Anything that
            is not a ``list`` is wrapped into a one-element list.
        jsonl_pth: Destination path, opened in append (``'a'``) mode.
        logger: Accepted but not used by this function.
    """
    records = list_of_dicts if isinstance(list_of_dicts, list) else [list_of_dicts]
    with jsonlines.open(jsonl_pth, mode='a') as appender:
        appender.write_all(records)

def get_dir_filepaths_pattern(search_dir, pattern):
    """Recursively collect files under ``search_dir`` whose name matches a glob.

    Args:
        search_dir: Directory to walk (subdirectories included).
        pattern: ``fnmatch``-style pattern tested against each file name
            (not the full path).

    Returns:
        A list of paths (``os.path.join`` of the walked directory and the
        file name) for every matching regular file.
    """
    matches = []
    for folder, _subdirs, names in os.walk(search_dir):
        candidates = (
            os.path.join(folder, name) for name in names if fnmatch(name, pattern)
        )
        matches.extend(path for path in candidates if os.path.isfile(path))
    return matches
                    
def get_dir_filepaths(search_dir, end_str=None):
    """Recursively collect file paths under ``search_dir``.

    Args:
        search_dir: Directory to walk (subdirectories included).
        end_str: Optional suffix (or tuple of suffixes) a file name must end
            with. ``None`` or an empty value disables the filter, so every
            file is returned.

    Returns:
        A list of paths for every regular file that passes the filter.
    """
    the_paths = []
    for root, _, files in os.walk(search_dir):
        for filename in files:
            # A falsy end_str means "no filter"; previously it made the
            # condition always false, so the default call returned nothing.
            if end_str and not filename.endswith(end_str):
                continue
            f_pth = os.path.join(root, filename)
            if os.path.isfile(f_pth):
                the_paths.append(f_pth)
    return the_paths
