#!/usr/bin/env python

"""
Set CICE cppdefs
"""

import os, sys, re

# Locate the CIME checkout from the environment; abort with a clear message
# if it is not set, since nothing below can be imported without it.
CIMEROOT = os.environ.get("CIMEROOT")
if CIMEROOT is None:
    raise SystemExit("ERROR: must set CIMEROOT environment variable")
# Make <CIMEROOT>/scripts/Tools importable; this must happen before the
# imports that follow (presumably where standard_script_setup lives).
sys.path.append(os.path.join(CIMEROOT, "scripts", "Tools"))

from standard_script_setup import *

from CIME.utils import run_cmd_no_fail, expect
from CIME.utils import run_cmd
from CIME.case import Case
from CIME.buildnml import parse_input

import glob, shutil
logger = logging.getLogger(__name__)

###############################################################################
def buildcpp(case):
###############################################################################
    """
    Compute the CICE CPP definitions for ``case`` and store them in the case.

    Reads the grid, decomposition, mode and CICE_CONFIG_OPTS settings from
    ``case``. When CICE_AUTO_DECOMP is true (and PTS_MODE is not), runs
    generate_cice_decomp.pl and writes the resulting block decomposition back
    to the case. Sets ICE_NCAT and CICE_CPPDEFS, flushes the case, and returns
    the cppdefs string.

    Fails through CIME's ``expect`` when the decomposition script fails or
    prints unusable output, or when CICE_CONFIG_OPTS names neither cice4 nor
    cice5 physics.
    """

    # determine cice_config_opts
    cice_config_opts = case.get_value('CICE_CONFIG_OPTS')

    # set ice grid settings
    nx = case.get_value("ICE_NX")
    ny = case.get_value("ICE_NY")

    # ar9v2 / ar9v4 are passed to the decomposition script as ar9v1 / ar9v3
    ice_grid = case.get_value('ICE_GRID')
    hgrid = {'ar9v2': 'ar9v1', 'ar9v4': 'ar9v3'}.get(ice_grid, ice_grid)
    is_ar9v_grid = "ar9v" in hgrid

    # If CICE_AUTO_DECOMP is true, generate_cice_decomp.pl is invoked below
    # and the decomposition xml variables are updated from its output.
    cice_auto_decomp = case.get_value("CICE_AUTO_DECOMP")

    pts_mode = case.get_value("PTS_MODE")
    if pts_mode:
        # explicitly set values for single column mode
        nx = 1
        ny = 1
        cice_auto_decomp = False
        case.set_value("CICE_BLCKX", 1)
        case.set_value("CICE_BLCKY", 1)
        case.set_value("CICE_MXBLCKS", 1)
        case.set_value("CICE_DECOMPTYPE", "cartesian")
        case.set_value("CICE_DECOMPSETTING", "square-ice")

    if cice_auto_decomp:
        ntasks_ice = case.get_value("NTASKS_ICE")
        nthrds_ice = case.get_value("NTHRDS_ICE")
        ninst_ice = case.get_value("NINST_ICE")
        # tasks available to a single ice instance
        ntasks = int(ntasks_ice / ninst_ice)
        srcroot = case.get_value("SRCROOT")

        cmd = os.path.join(srcroot, "components", "cice", "bld", "generate_cice_decomp.pl")
        command = "%s -ccsmroot %s -res %s -nx %s -ny %s -nproc %s -thrds %s -output %s "  \
                  % (cmd, srcroot, hgrid, nx, ny, ntasks, nthrds_ice, "all")
        rc, out, err = run_cmd(command)

        # report the full command line (not just the script path) on failure
        expect(rc==0,"Command %s failed rc=%d\nout=%s\nerr=%s"%(command,rc,out,err))
        if out is not None:
            logger.debug("     %s"%out)
        if err is not None:
            logger.debug("     %s"%err)

        config = out.split()
        # The first field is compared numerically; the script output is text,
        # so it must be converted before comparing against 0.
        try:
            first_field = int(config[0])
        except (IndexError, ValueError):
            first_field = None
        expect(first_field is not None,
               "Command %s produced unexpected output: '%s'"%(command, out))

        if first_field > 0:
            expect(len(config) >= 7,
                   "Command %s produced too few fields: '%s'"%(command, out))
            case.set_value("CICE_BLCKX", config[2])
            case.set_value("CICE_BLCKY", config[3])
            case.set_value("CICE_MXBLCKS",config[4])
            case.set_value("CICE_DECOMPTYPE", config[5])
            case.set_value("CICE_DECOMPSETTING", config[6])

    # set cice mode
    cice_mode = case.get_value("CICE_MODE")
    prescribed = (cice_mode == 'prescribed')

    # set cice physics; fail clearly instead of hitting an unbound name later
    if "cice5" in cice_config_opts:
        phys = "cice5"
    elif "cice4" in cice_config_opts:
        phys = "cice4"
    else:
        phys = None
    expect(phys is not None,
           "CICE_CONFIG_OPTS must select cice4 or cice5 physics, got '%s'"%cice_config_opts)

    # prescribed mode overrides whatever decomposition type was chosen above
    if prescribed:
        case.set_value("CICE_DECOMPTYPE", "roundrobin")

    # Each setting below has a default that an explicit "-<name> <n>" entry
    # in CICE_CONFIG_OPTS overrides.

    # set number of aerosol tracers
    ntr_aero = set_nondefault_cpp(cice_config_opts, "ntr_aero", 0 if prescribed else 3)

    # set nbgclyr (valid values are 0->10)
    nbgclyr = set_nondefault_cpp(cice_config_opts, "nbgclyr", 0 if prescribed else 3)

    # set isotope tracer
    ntr_iso = set_nondefault_cpp(cice_config_opts, "ntr_iso", 0)

    # set age tracer
    trage = set_nondefault_cpp(cice_config_opts, "trage", 1)

    # set first year ice tracer (valid values are 0,1)
    trfy = set_nondefault_cpp(cice_config_opts, "trfy", 1)

    # set pond tracer (valid values are 0,1)
    trpnd = set_nondefault_cpp(cice_config_opts, "trpnd", 1)

    # set level ice tracer (valid values are 0,1)
    trlvl = set_nondefault_cpp(cice_config_opts, "trlvl", 1)

    # set brine tracer (valid values are 0,1)
    trbri = set_nondefault_cpp(cice_config_opts, "trbri", 0)

    # set skeletal layer tracer (valid values are 0,1)
    trbgcs = set_nondefault_cpp(cice_config_opts, "trbgcs", 0)

    # set number of ice layers
    if phys == "cice4":
        nicelyr = 4
    elif is_ar9v_grid:
        nicelyr = 7
    else:
        nicelyr = 8

    # set number of snow layers
    nsnwlyr = 1 if phys == "cice4" else 3

    # set number of ice categories
    # NOTE that ICE_NCAT is used by both cice and pop - but is set by cice
    # and as a result it is assumed that the cice buildcpp is called
    # BEFORE the pop buildcpp. This order is set by the xml variable
    # COMP_CLASSES in the driver config_component.xml file
    ncat = set_nondefault_cpp(cice_config_opts, "ncat", 1 if prescribed else 5)
    case.set_value("ICE_NCAT",ncat)
    logger.debug("cice: number of ice categories (ncat) is %s" %ncat)

    # set decomposition block sizes (read back after any updates made above)
    cice_blckx   = case.get_value("CICE_BLCKX")
    cice_blcky   = case.get_value("CICE_BLCKY")
    cice_mxblcks = case.get_value("CICE_MXBLCKS")

    cice_cppdefs = " -DCESMCOUPLED -Dncdf -DNUMIN=11 -DNUMAX=99 " \
                   " -DNICECAT=%s -DNXGLOB=%s -DNYGLOB=%s -DNTRAERO=%s -DNTRISO=%s" \
                   " -DNBGCLYR=%s -DNICELYR=%s -DNSNWLYR=%s"  \
                   " -DTRAGE=%s -DTRFY=%s -DTRLVL=%s -DTRPND=%s -DTRBRI=%s -DTRBGCS=%s" \
                   %(ncat,nx,ny,ntr_aero,ntr_iso,nbgclyr,nicelyr,nsnwlyr,trage,trfy,trlvl,trpnd,trbri,trbgcs)

    # trigger RASM options with ar9v grid, otherwise set CESM options
    # (leading space keeps this flag separate from the preceding -DTRBGCS=<n>)
    if is_ar9v_grid:
        cice_cppdefs = cice_cppdefs + " -DRASM_MODS"

    # determine cice_cppdefs used in build
    cice_cppdefs = cice_cppdefs + " -DBLCKX=%s -DBLCKY=%s -DMXBLCKS=%s"%(cice_blckx, cice_blcky, cice_mxblcks)

    # update the xml variable CICE_CPPDEFS with the above definition
    case.set_value("CICE_CPPDEFS", cice_cppdefs)
    case.flush()

    return cice_cppdefs

###############################################################################
def set_nondefault_cpp(cice_config_opts, string, value):
    """
    Return the setting for option ``string``, honoring an explicit override.

    If ``cice_config_opts`` contains ``-<string>`` followed by digits, those
    digits are returned (as a str). Otherwise ``value`` is returned unchanged.
    """
    # quick exit: the option name does not appear in the options at all
    if string not in cice_config_opts:
        return value

    # the name is present; only override when it is followed by digits
    found = re.search(r"\s*-%s\s*(\d+)\s*"%string, cice_config_opts)
    return value if found is None else found.group(1)

###############################################################################
def _main_func():
    """Script entry point: build the cppdefs for the case given in sys.argv."""

    # parse_input turns the command-line arguments into the case root
    caseroot = parse_input(sys.argv)

    with Case(caseroot) as case:
        cppdefs = buildcpp(case)

    # report the result once the case context has been exited
    logger.info("CICE_CPPDEFS: %s" %cppdefs)

# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    _main_func()
