#!/usr/bin/env python

"""
This utility allows the CIME user to view and modify a case's PE layout.

Interpreted format sequences are:
%%  a literal %
%C  the component name
%T  the task count for the component
%H  the thread count for the component
%R  the PE root for the component

Standard format extensions, such as a field length and padding are supported.
Example:
  For NTASKS_ATM=480, NTHRDS_ATM=2 and ROOTPE_ATM=0, \"%C: %06T/%+H\" will show
ATM: 000480/+2
Note that since %R was omitted, the root PE was not shown.
Python str.format-style replacement fields are also supported. For instance,
--format "{C:4}" will print the component name padded to a width of 4 characters.
=======================
If this tool is missing any feature that you need, please go to github and
create an issue (https://github.com/ESMCI/cime/issues). Detailed
information is available at https://github.com/ESMCI/cime.

    COPYRIGHT AND LICENSE

        This library is free software; you can redistribute it and/or modify
        it under the same terms as Python itself.
"""

from standard_script_setup import *

from CIME.case import Case
from CIME.utils import expect, convert_to_string
import textwrap
import sys
import re

# Module-level logger.
# NOTE(review): the "xmlquery" name looks copied from the xmlquery tool, and
# this logger is not referenced anywhere in the visible source -- confirm
# whether it should be renamed or removed.
logger = logging.getLogger("xmlquery")

###############################################################################
def parse_command_line(args):
###############################################################################
    """
    Build the pelayout argument parser and parse the given argument list.

    args is the full argument list (the caller passes sys.argv); it is handed,
    together with the parser, to
    CIME.utils.parse_args_and_handle_standard_logging_options.

    Returns the tuple (format, set_ntasks, set_nthrds, header, caseroot).
    set_ntasks and set_nthrds are None unless given on the command line, and
    header is None when --no-header was requested.
    """
    parser = argparse.ArgumentParser(description="Display and/or change the case's PE layout.",
                                     formatter_class=argparse.RawDescriptionHelpFormatter,
                                     epilog=textwrap.dedent(__doc__))
    CIME.utils.setup_standard_logging_options(parser)

    # Set command line options
    parser.add_argument("-format", "--format",
                        default="%4C: %6T/%6H; %6R",
                        help="Format the PE layout items for each component (see below)")

    # "-set-nasks" was a misspelling of the single-dash alias; it is kept so
    # existing invocations keep working, and the correct "-set-ntasks" is added.
    # The destination is still taken from "--set-ntasks" (args.set_ntasks).
    parser.add_argument("-set-ntasks", "-set-nasks", "--set-ntasks", default=None,
                        help="Total number of tasks to set for the case")

    parser.add_argument("-set-nthrds", "--set-nthrds",
                        "-set-nthreads", "--set-nthreads", default=None,
                        help="Number of threads to set for all components")

    parser.add_argument("-header", "--header",
                        default="Comp  NTASKS  NTHRDS  ROOTPE",
                        help="Custom header for PE layout display")

    parser.add_argument("-no-header", "--no-header", default=False, action="store_true",
                        help="Do not print any PE layout header")

    parser.add_argument("-caseroot", "--caseroot", default=os.getcwd(),
                        help="Case directory to reference")

    args = CIME.utils.parse_args_and_handle_standard_logging_options(args, parser)

    # --no-header overrides any (default or custom) header text
    if args.no_header:
        args.header = None
    # End if

    return args.format, args.set_ntasks, args.set_nthrds, args.header, args.caseroot
# End def parse_command_line


###############################################################################
def get_value_as_string(case, var, attribute=None, resolved=False, subgroup=None):
###############################################################################
    """
    Look up var in case and return it converted with convert_to_string.

    The raw value is returned untouched when it is None or when the case
    reports no (truthy) type information for var.
    """
    var_type = case.get_type_info(var)
    raw = case.get_value(var, attribute=attribute, resolved=resolved, subgroup=subgroup)
    if raw is None or not var_type:
        # Nothing to convert, or no type to convert with
        return raw
    return convert_to_string(raw, var_type, var)

###############################################################################
def format_pelayout(comp, ntasks, nthreads, rootpe, arg_format):
###############################################################################
    """
    Format the PE layout information for one component according to
    arg_format and return the resulting string.

    arg_format may mix printf-like directives with str.format fields:
      %%        a literal %
      %<w>C     the component name (<w>: optional digits)
      %<spec>T  the task count     (<spec>: optional digits, '+' and '-')
      %<spec>H  the thread count
      %<spec>R  the root PE
      {C}, {T}, {H}, {R} (optionally with ':spec') are used as written.
    Any other % sequence is left untouched. Because the result is passed
    through str.format, literal braces must be doubled ({{ and }}).
    """
    subs = { 'C': comp, 'T': ntasks, 'H': nthreads, 'R': rootpe }

    def _directive_to_field(match):
        # Translate one %-directive into its str.format equivalent
        if match.group("pct"):
            return "%"
        if match.group("num"):
            return "{%s:%s}" % (match.group("num"), match.group("nspec"))
        return "{C:%s}" % match.group("cspec")

    # A single left-to-right pass, so that "%%" is consumed as a unit and the
    # character following it is never mistaken for a directive letter.
    layout_str = re.sub(r"%(?:(?P<pct>%)|(?P<cspec>[0-9]*)C|(?P<nspec>[-+0-9]*)(?P<num>[THR]))",
                        _directive_to_field, arg_format)
    return layout_str.format(**subs)
# End def format_pelayout

###############################################################################
def print_pelayout(case, ntasks, nthreads, rootpes, arg_format, header):
###############################################################################
    """
    Print the PE layout, one line per component class.

    case is queried for its COMP_CLASSES; ntasks, nthreads and rootpes are
    dictionaries keyed by component class. Each line is produced by
    format_pelayout using arg_format. header is printed first unless it
    is None.
    """
    comp_classes = case.get_values("COMP_CLASSES")

    # The single-argument, parenthesized form of print behaves identically
    # under Python 2 and Python 3.
    if header is not None:
        print(header)
    # End if
    for comp in comp_classes:
        print(format_pelayout(comp, ntasks[comp], nthreads[comp], rootpes[comp], arg_format))
    # End for

# End def print_pelayout

###############################################################################
def gather_pelayout(case):
###############################################################################
    """
    Collect NTASKS, NTHRDS and ROOTPE (as integers) for every component class.

    Returns three dictionaries, in that order, each keyed by component class.
    """
    fields = ("NTASKS", "NTHRDS", "ROOTPE")
    layout = dict((field, {}) for field in fields)

    for comp in case.get_values("COMP_CLASSES"):
        for field in fields:
            layout[field][comp] = int(case.get_value(field + "_" + comp))

    return layout["NTASKS"], layout["NTHRDS"], layout["ROOTPE"]
# End def gather_pelayout

###############################################################################
def set_nthreads(case, nthreads):
###############################################################################
    """
    Apply the same thread count to every component class of the case.

    For each entry of COMP_CLASSES, case.set_value is called with "NTHRDS",
    the given nthreads value, and the component class as third argument.
    """
    for component in case.get_values("COMP_CLASSES"):
        case.set_value("NTHRDS", nthreads, component)
# End def set_nthreads

###############################################################################
def modify_ntasks(case, new_tot_tasks):
###############################################################################
    """
    Rescale the case's PE layout so that it spans new_tot_tasks tasks.

    The current total is the largest (ROOTPE + NTASKS) found over all
    component classes. Every component's NTASKS and ROOTPE is scaled by
    new_tot_tasks / <current total>. Each scaled value is checked with
    expect(); the check fails unless the scaling is exact (a whole number).
    The case is only modified after all checks have been made, and nothing
    is done when new_tot_tasks already equals the current total.
    """
    comp_classes = case.get_values("COMP_CLASSES")

    # First, gather current task and root pe info
    curr_tasks, _, curr_roots = gather_pelayout(case)

    # How many tasks are currently being used? (never less than zero)
    curr_tot_tasks = max([0] + [curr_tasks[comp] + curr_roots[comp]
                                for comp in comp_classes])

    if new_tot_tasks == curr_tot_tasks:
        return
    # End if (#tasks unchanged)

    # Both totals are used as divisors below
    expect(new_tot_tasks > 0,
           "New total task count must be positive, got {}".format(new_tot_tasks))
    expect(curr_tot_tasks > 0,
           "Current PE layout uses no tasks; cannot rescale it")

    # Compute new task counts and root pes.
    # '//' keeps the arithmetic in integers under both Python 2 and Python 3.
    new_tasks = {}
    new_roots = {}
    for comp in comp_classes:
        new_tasks[comp] = curr_tasks[comp] * new_tot_tasks // curr_tot_tasks
        new_roots[comp] = curr_roots[comp] * new_tot_tasks // curr_tot_tasks
    # End for

    # Check for valid recomputation: scaling back must reproduce the original
    # value, which only holds when the forward scaling lost no remainder.
    for comp in comp_classes:
        expect(new_tasks[comp] * curr_tot_tasks // new_tot_tasks == curr_tasks[comp],
               "Task change invalid for {}".format(comp))
        expect(new_roots[comp] * curr_tot_tasks // new_tot_tasks == curr_roots[comp],
               "Root PE change invalid for {}".format(comp))
    # End for

    # We got this far? Go ahead and change PE layout
    for comp in comp_classes:
        case.set_value("NTASKS_"+comp, new_tasks[comp], comp)
        case.set_value("ROOTPE_"+comp, new_roots[comp], comp)
    # End for
# End def modify_ntasks

###############################################################################
def _main_func():
###############################################################################
    """
    Script entry point.

    With "--test" anywhere in sys.argv, run this module's doctests and exit
    (status 1 if any failed, 0 otherwise). Otherwise parse the command line,
    open the case, apply any requested thread/task changes and print the
    resulting PE layout.
    """
    if "--test" in sys.argv:
        outcome = doctest.testmod(verbose=True)
        sys.exit(1 if outcome.failed > 0 else 0)

    # Command line options, in the order returned by parse_command_line
    layout_format, new_ntasks, new_nthrds, header_line, case_dir = parse_command_line(sys.argv)

    # Open the case for the duration of the updates and the report
    with Case(case_dir) as case:
        # Threads are updated first, then the total task count
        if new_nthrds is not None:
            set_nthreads(case, new_nthrds)
        if new_ntasks is not None:
            modify_ntasks(case, int(new_ntasks))

        # Report the layout as it stands after any changes
        tasks_by_comp, threads_by_comp, roots_by_comp = gather_pelayout(case)
        print_pelayout(case, tasks_by_comp, threads_by_comp, roots_by_comp,
                       layout_format, header_line)
    # End with

# End def _main_func

# Run the tool only when this file is executed as a script
if __name__ == "__main__":
    _main_func()
# End if
