#!/usr/bin/env python
"""Namelist creator for CIME's driver.
"""
# Typically ignore this.
# pylint: disable=invalid-name

# Disable these because this is our standard setup
# pylint: disable=wildcard-import,unused-wildcard-import,wrong-import-position

import os, shutil, sys, glob, itertools

_CIMEROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..","..","..","..")
sys.path.append(os.path.join(_CIMEROOT, "scripts", "Tools"))

from standard_script_setup import *
from CIME.case import Case
from CIME.nmlgen import NamelistGenerator
from CIME.utils import expect
from CIME.utils import get_model, get_time_in_seconds, get_timestamp
from CIME.buildnml import create_namelist_infile, parse_input
from CIME.XML.files import Files

logger = logging.getLogger(__name__)

###############################################################################
def _create_drv_namelists(case, infile, confdir, nmlgen, files):
###############################################################################
    """Write the driver namelist files drv_in, seq_maps.rc and drv_flds_in.

    Arguments:
    case    -- case object; queried via get_value/get_values/get_case_root
    infile  -- namelist input file(s) handed to nmlgen.init_defaults
    confdir -- directory in which drv_in and seq_maps.rc are written
    nmlgen  -- namelist generator used for drv_in and seq_maps.rc
    files   -- Files object used to locate the drv_flds namelist definition

    drv_flds_in is only produced when at least one of the files named by the
    'drv_flds_in_files' default exists under CASEROOT; it is written to
    CASEROOT/CaseDocs/drv_flds_in.

    Problems are reported through expect(): an unsupported NCPL_BASE_PERIOD or
    CALENDAR, a <COMP>_NCPL that does not divide the base period evenly, an
    unknown entry in pause_component_list, or conflicting drv_flds_in inputs.
    """

    #--------------------------------
    # Set up config dictionary
    #--------------------------------
    config = {}
    config['cime_model'] = get_model()
    config['BGC_MODE'] = case.get_value("CCSM_BGC")
    # These settings are stored under the name of the case variable itself.
    for name in ('CPL_I2O_PER_CAT', 'COMP_RUN_BARRIERS', 'DRV_THREADING',
                 'CPL_ALBAV', 'CPL_EPBAL', 'FLDS_WISO', 'BUDGETS',
                 'MACH', 'MPILIB', 'OS'):
        config[name] = case.get_value(name)
    glc_nec = case.get_value('GLC_NEC')
    config['glc_nec'] = 0 if glc_nec == 0 else glc_nec
    config['single_column'] = 'true' if case.get_value('PTS_MODE') else 'false'
    config['timer_level'] = 'pos' if case.get_value('TIMER_LEVEL') >= 1 else 'neg'
    config['bfbflag'] = 'on' if case.get_value('BFBFLAG') else 'off'
    config['continue_run'] = '.true.' if case.get_value('CONTINUE_RUN') else '.false.'

    # A hybrid run is presented to the namelist defaults as a startup run;
    # any other RUN_TYPE leaves 'run_type' out of the config.
    run_type = case.get_value('RUN_TYPE')
    if run_type in ('startup', 'hybrid'):
        config['run_type'] = 'startup'
    elif run_type == 'branch':
        config['run_type'] = 'branch'

    #----------------------------------------------------
    # Initialize namelist defaults
    #----------------------------------------------------
    nmlgen.init_defaults(infile, config)

    #--------------------------------
    # Overwrite: set brnch_retain_casename
    #--------------------------------
    if (nmlgen.get_value('start_type') != 'startup' and
            case.get_value('CASE') == case.get_value('RUN_REFCASE')):
        nmlgen.set_value('brnch_retain_casename', value='.true.')

    #--------------------------------
    # Overwrite: set component coupling frequencies
    #--------------------------------
    basedt = _get_ncpl_base_dt(case)

    mindt = basedt
    for comp in case.get_values("COMP_CLASSES"):
        ncpl = case.get_value(comp.upper() + '_NCPL')
        if ncpl is not None:
            # divmod keeps the arithmetic in integers on both Python 2 and 3,
            # so cpl_dt stays an int and the remainder test is exact.
            cpl_dt, remainder = divmod(basedt, int(ncpl))
            expect(remainder == 0, " %s ncpl doesn't divide base dt evenly" %comp)
            nmlgen.add_default(comp.lower() + '_cpl_dt', value=cpl_dt)
            mindt = min(mindt, cpl_dt)

    #--------------------------------
    # Overwrite: set start_ymd
    #--------------------------------
    run_startdate = "".join(str(x) for x in case.get_value('RUN_STARTDATE').split('-'))
    nmlgen.set_value('start_ymd', value=run_startdate)

    #--------------------------------
    # Overwrite: set tprof_option and tprof_n - if tprof_total is > 0
    #--------------------------------
    # This would be better handled inside the alarm logic in the driver routines.
    # Here supporting only nday(s), nmonth(s), and nyear(s).

    stop_option = case.get_value('STOP_OPTION')
    if 'nyear' in stop_option:
        tprofoption, tprofmult = 'ndays', 365
    elif 'nmonth' in stop_option:
        tprofoption, tprofmult = 'ndays', 30
    elif 'nday' in stop_option:
        tprofoption, tprofmult = 'ndays', 1
    else:
        tprofoption, tprofmult = 'never', 1

    tprof_total = case.get_value('TPROF_TOTAL')
    if tprof_total > 0 and case.get_value('STOP_DATE') < 0 and 'ndays' in tprofoption:
        # Run length expressed in days, split into tprof_total intervals
        # (never less than one day per interval).
        stopn = tprofmult * case.get_value('STOP_N')
        tprofn = max(1, stopn // tprof_total)
        nmlgen.set_value('tprof_option', value=tprofoption)
        nmlgen.set_value('tprof_n', value=tprofn)

    # Set up the pause_component_list if pause is active
    pauseo = case.get_value('PAUSE_OPTION')
    if pauseo not in ('never', 'none'):
        pausen = case.get_value('PAUSE_N')
        pcl = nmlgen.get_default('pause_component_list')
        nmlgen.add_default('pause_component_list', pcl)
        # Check to make sure pause_component_list is valid
        pcl = nmlgen.get_value('pause_component_list')
        if pcl not in ('none', 'all'):
            comp_classes = case.get_values("COMP_CLASSES")
            for comp in pcl.split(':'):
                expect(comp == 'drv' or comp.upper() in comp_classes,
                       "Invalid PAUSE_COMPONENT_LIST, %s is not a valid component type"%comp)
        # Set esp interval: the smallest coupling interval when pausing by
        # step, otherwise the pause period converted to seconds.
        if 'nstep' in pauseo:
            esp_time = mindt
        else:
            esp_time = get_time_in_seconds(pausen, pauseo)

        nmlgen.set_value('esp_cpl_dt', value=esp_time)

    #--------------------------------
    # (1) Write output namelist file drv_in and input dataset list.
    #--------------------------------
    data_list_path = os.path.join(case.get_case_root(), "Buildconf", "cpl.input_data_list")
    if os.path.exists(data_list_path):
        os.remove(data_list_path)
    nmlgen.write_output_file(os.path.join(confdir, "drv_in"), data_list_path)

    #--------------------------------
    # (2) Write out seq_map.rc file
    #--------------------------------
    nmlgen.write_seq_maps(os.path.join(confdir, "seq_maps.rc"))

    #--------------------------------
    # (3) Construct and write out drv_flds_in
    #--------------------------------
    # In the following, all values come simply from the infiles - no default values need to be added
    # FIXME - do want to add the possibility that will use a user definition file for drv_flds_in

    caseroot = case.get_value('CASEROOT')

    # NOTE: this default is only added after drv_in has been written above.
    nmlgen.add_default('drv_flds_in_files')
    candidates = [os.path.join(caseroot, drvflds_file)
                  for drvflds_file in nmlgen.get_default('drv_flds_in_files')]
    infiles = [path for path in candidates if os.path.isfile(path)]
    if not infiles:
        return

    # First read the drv_flds_in files and make sure that
    # for any key there are not two conflicting values
    dicts = {}
    for path in infiles:
        dicts[path] = _read_drv_flds_settings(path)

    for first, second in itertools.combinations(dicts.keys(), 2):
        compare_drv_flds_in(dicts[first], dicts[second], first, second)

    # Now create drv_flds_in with a generator of its own; the driver
    # generator passed in by the caller is left untouched.
    definition_file = [files.get_value("NAMELIST_DEFINITION_FILE", attribute={"component":"drv_flds"})]
    flds_nmlgen = NamelistGenerator(case, definition_file, files=files)
    flds_nmlgen.init_defaults(infiles, {}, skip_entry_loop=True)
    flds_nmlgen.write_output_file(os.path.join(caseroot, "CaseDocs", "drv_flds_in"))

###############################################################################
def _get_ncpl_base_dt(case):
###############################################################################
    """Return the length in seconds of the case's NCPL_BASE_PERIOD."""
    seconds_per_day = 3600 * 24
    ncpl_base_period = case.get_value('NCPL_BASE_PERIOD')
    if ncpl_base_period == 'hour':
        return 3600
    if ncpl_base_period == 'day':
        return seconds_per_day
    if ncpl_base_period in ('year', 'decade'):
        # Year-based periods are computed with fixed 365-day years, so they
        # are only accepted for the NO_LEAP calendar.
        expect(case.get_value('CALENDAR') == 'NO_LEAP',
               "Invalid CALENDAR for NCPL_BASE_PERIOD %s " %ncpl_base_period)
        nyears = 1 if ncpl_base_period == 'year' else 10
        return seconds_per_day * 365 * nyears
    expect(False, "invalid NCPL_BASE_PERIOD NCPL_BASE_PERIOD %s " %ncpl_base_period)
    return None

###############################################################################
def _read_drv_flds_settings(path):
###############################################################################
    """Return the name -> value settings found in one drv_flds_in input file."""
    settings = {}
    with open(path) as myfile:
        for line in myfile:
            # Only lines holding an assignment are kept; any line containing
            # '!' (presumably a Fortran-style comment) is skipped entirely.
            if "=" in line and '!' not in line:
                name, _, value = line.partition("=")
                settings[name.strip()] = value.strip()
    return settings

###############################################################################
def compare_drv_flds_in(first, second, infile1, infile2):
###############################################################################
    """Abort if two drv_flds_in setting dictionaries disagree on a shared key.

    first, second    -- name -> value dictionaries to compare
    infile1, infile2 -- names of the files the dictionaries came from; only
                        used in the error message

    Keys present in just one dictionary are ignored. On the first shared key
    whose values differ, the key and both values are printed and expect()
    is called with a failing condition.
    """
    common_names = set(first).intersection(second)
    for name in common_names:
        value1 = first[name]
        value2 = second[name]
        if value1 != value2:
            print('Key: {}, \n Value 1: {}, \n Value 2: {}'.format(name, value1, value2))
            message = ("incompatible settings in drv_flds_in from \n %s \n and \n %s"
                       % (infile1, infile2))
            expect(False, message)

###############################################################################
def _create_component_modelio_namelists(case, files):
###############################################################################
    """Write the <model>_modelio.nml files into CASEBUILD/cplconf.

    case  -- case object queried through get_value
    files -- Files object used to locate the modelio namelist definition

    One file is written per instance of each model in the fixed list below.
    'cpl' always has a single instance; the other models use NINST_<MODEL>.
    When a model has more than one instance, the log file name and the
    namelist file name carry a zero-padded instance number (width 4).
    """

    # will need to create a new namelist generator
    infiles = []
    definition_file = [files.get_value("NAMELIST_DEFINITION_FILE", attribute={"component":"modelio"})]

    confdir = os.path.join(case.get_value("CASEBUILD"), "cplconf")
    # Use the LID from the environment when it is set, otherwise a fresh timestamp.
    lid = os.environ["LID"] if "LID" in os.environ else get_timestamp("%y%m%d-%H%M%S")

    models = ["cpl", "atm", "lnd", "ice", "ocn", "glc", "rof", "wav", "esp"]
    for model in models:
        with NamelistGenerator(case, definition_file) as nmlgen:
            config = {'component': model}
            entries = nmlgen.init_defaults(infiles, config, skip_entry_loop=True)

            if model == 'cpl':
                inst_count = 1
            else:
                inst_count = case.get_value("NINST_" + model.upper())

            for inst_index in range(1, inst_count + 1):
                # Always a string, also for indices of 1000 and above
                # (which are simply not padded any further).
                inst_string = "%04d" % inst_index

                # set default values
                for entry in entries:
                    nmlgen.add_default(entry.get("id"))

                # overwrite defaults
                nmlgen.set_value('diri', case.get_value('EXEROOT') + "/" + model)
                nmlgen.set_value('diro', case.get_value('RUNDIR'))

                # Instance numbers only appear in the names for multi-instance models.
                if inst_count > 1:
                    logfile = model + "_" + inst_string + ".log." + str(lid)
                    modelio_file = model + "_modelio.nml_" + inst_string
                else:
                    logfile = model + ".log." + str(lid)
                    modelio_file = model + "_modelio.nml"
                nmlgen.set_value('logfile', logfile)

                # Write output file
                nmlgen.write_modelio_file(os.path.join(confdir, modelio_file))

###############################################################################
def buildnml(case, caseroot, component):
###############################################################################
    """Create the driver namelist files and copy them into the run directory.

    case      -- case object queried through get_value
    caseroot  -- case root directory; SourceMods/src.drv, user_nl_cpl and
                 CaseDocs are looked up beneath it
    component -- must be "drv"; any other value raises AttributeError

    The files are generated in CASEBUILD/cplconf (created if missing) and
    drv_in, seq_maps.rc, every *modelio* file and, when it exists,
    CaseDocs/drv_flds_in are then copied to RUNDIR.
    """
    if component != "drv":
        raise AttributeError

    cplconf = os.path.join(case.get_value("CASEBUILD"), "cplconf")
    if not os.path.isdir(cplconf):
        os.makedirs(cplconf)

    # NOTE: User definition *replaces* existing definition.
    # TODO: Append instead of replace?
    srcmods_drv = os.path.join(caseroot, "SourceMods", "src.drv")
    expect(os.path.isdir(srcmods_drv),
           "user_xml_dir %s does not exist " %srcmods_drv)

    files = Files()
    stock_definition = files.get_value("NAMELIST_DEFINITION_FILE", {"component": "drv"})
    user_definition = os.path.join(srcmods_drv, "namelist_definition_drv.xml")
    if os.path.isfile(user_definition):
        definition_file = [user_definition]
    else:
        definition_file = [stock_definition]

    # The namelist generator object is independent of instance
    nmlgen = NamelistGenerator(case, definition_file)

    # Build cplconf/namelist_infile; cam with aquaplanet options is actually
    # changing the driver namelist settings.
    if (case.get_value('COMP_ATM') == 'cam' and
            "aquaplanet" in case.get_value("CAM_CONFIG_OPTS")):
        infile_text = "aqua_planet = .true. \n aqua_planet_sst = 1"
    else:
        infile_text = ""

    namelist_infile = os.path.join(cplconf, "namelist_infile")
    create_namelist_infile(case, os.path.join(caseroot, "user_nl_cpl"),
                           namelist_infile, infile_text)

    # drv_in, drv_flds_in and seq_maps.rc
    _create_drv_namelists(case, [namelist_infile], cplconf, nmlgen, files)

    # comp_modelio.nml where comp = [atm, lnd...]
    _create_component_modelio_namelists(case, files)

    # Copy drv_in, drv_flds_in, seq_maps.rc and all *modelio* files to rundir
    rundir = case.get_value("RUNDIR")

    shutil.copy(os.path.join(cplconf, "drv_in"), rundir)

    drv_flds_in = os.path.join(caseroot, "CaseDocs", "drv_flds_in")
    if os.path.isfile(drv_flds_in):
        shutil.copy(drv_flds_in, rundir)

    shutil.copy(os.path.join(cplconf, "seq_maps.rc"), rundir)

    for modelio_path in glob.glob(os.path.join(cplconf, "*modelio*")):
        shutil.copy(modelio_path, rundir)

###############################################################################
def _main_func():
    """Script entry point: build the driver namelists for one case.

    The value returned by parse_input(sys.argv) is used as the case root,
    both to open the Case and as the caseroot argument of buildnml.
    """
    case_dir = parse_input(sys.argv)

    with Case(case_dir) as opened_case:
        buildnml(opened_case, case_dir, "drv")

if __name__ == "__main__":
    _main_func()
