#!/usr/bin/python3
# coding: utf-8

# Copyright (c) 2023 NVIDIA Corporation.  All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are those
# of the authors and should not be interpreted as representing official policies,
# either expressed or implied, of the FreeBSD Project.

import time
import json
from collections import defaultdict
import os

import argparse

# Placeholder value; main() overwrites it with the hwmon sub-folder whose
# 'name' file contains 'bfperf'.
sys_fs_prefix = '/to/be/initialized/'

# Folder whose sub-folders are scanned by main() for the counters device.
sys_fs_parent_folder = '/sys/class/hwmon'

# Prefix put in front of the shell commands that touch sysfs.  Empty by
# default; 'taskset -c 1' is the alternative left by the original author.
set_perf_core = ''

# NOTE(review): not referenced anywhere in the visible part of this file.
dev = "/dev/mst/mt41692_pciconf0"

# The three tables below translate the internal event names used as keys of
# the collected counters into the names shown in the printed report.  The
# extra 'clock' entry covers the value read from counter0.
llt_int2ext_name_dict = {
    'gdc_bank0_rd_req': 'BANK0_RD',
    'gdc_bank1_rd_req': 'BANK1_RD',
    'gdc_bank0_wr_req': 'BANK0_WR',
    'gdc_bank1_wr_req': 'BANK1_WR',
    'clock': 'CLOCK',
}

llt_miss_int2ext_name_dict = {
    'gdc_miss_machine_rd_req': 'RD_MISS',
    'gdc_miss_machine_wr_req': 'WR_MISS',
    'clock': 'CLOCK',
}

mss_int2ext_name_dict = {
    'skylib_ddn_tx_flits': 'DDR_RD',
    'skylib_ddn_rx_flits': 'DDR_WR',
    'clock': 'CLOCK',
}

def module_ctrs_to_hierarchical_dicts(llt_ctrs_dict,int2ext_name_dict,show=False):
    """Regroup flat counter readings into two nested views.

    Args:
        llt_ctrs_dict: mapping whose keys are indexable pairs
            ``(module_name, internal_counter_name)`` and whose values are
            base-16 number strings.
        int2ext_name_dict: translation from internal counter name to the
            external name used in the result.
        show: when true, print both views as indented JSON separated by a
            ``===`` line.

    Returns:
        A tuple of two ``defaultdict(dict)`` objects:
        ``{module_name: {external_name: value}}`` and
        ``{external_name: {module_name: value}}``, with values converted to int.
    """
    per_module = defaultdict(dict)
    per_counter = defaultdict(dict)

    for key, raw_hex in llt_ctrs_dict.items():
        module_name = key[0]
        external_name = int2ext_name_dict[key[1]]
        per_module[module_name][external_name] = int(raw_hex, 16)
        per_counter[external_name][module_name] = int(raw_hex, 16)

    if show:
        module_view_json = json.dumps(dict(per_module), indent=4)
        counter_view_json = json.dumps(per_counter, indent=4)
        print(module_view_json)
        print('===')
        print(counter_view_json)

    return per_module, per_counter


# Instance indices visited for every module type.
module2range = {
    'llt': range(8),
    'llt_miss': range(8),
    'mss': range(2),
}

# For every module type: internal event name -> the event id that gets
# programmed and the counter slot (ctr_id) that counts it.
module2intEvents_ids_dict = {
    'llt': {
        'gdc_bank0_rd_req': {'event_id': 84, 'ctr_id': 1},
        'gdc_bank1_rd_req': {'event_id': 123, 'ctr_id': 2},
        'gdc_bank0_wr_req': {'event_id': 85, 'ctr_id': 3},
        'gdc_bank1_wr_req': {'event_id': 124, 'ctr_id': 4},
    },
    'llt_miss': {
        'gdc_miss_machine_rd_req': {'event_id': 0, 'ctr_id': 1},
        'gdc_miss_machine_wr_req': {'event_id': 1, 'ctr_id': 2},
    },
    'mss': {
        'skylib_ddn_tx_flits': {'event_id': 1, 'ctr_id': 1},
        'skylib_ddn_rx_flits': {'event_id': 6, 'ctr_id': 2},
    },
}

def _run_shell(cmd):
    """Run *cmd* through the shell, wait for it to finish and return its stdout text."""
    # Leaving the with-block closes the pipe returned by os.popen(), and that
    # close waits for the child process.  Commands issued back-to-back are
    # therefore applied strictly in program order, and no pipe or child
    # process is left behind.
    with os.popen(cmd) as pipe:
        return pipe.read()


def _module_file(module_type, module_idx, file_name):
    """Return the path of *file_name* inside the sysfs folder of one module instance."""
    return os.path.join(sys_fs_prefix, f'{module_type}{module_idx}', file_name)


def pmc_perf_ctrs_collect(set_perf_core, _module2intEvents_ids_dict, _module2range, module_type, recording_sleep_time_secs, verbose=False):
    """Program, run and read back the counters of every instance of one module type.

    All instances are disabled first.  Then, one instance at a time:
    ``count_clock`` is set to 1, ``counter0`` and the counter/event slots
    1..15 are cleared, the requested events are written to their counter
    slots, the instance is enabled for ``recording_sleep_time_secs`` seconds,
    disabled again, and its counters are read.  Instances are sampled one
    after the other, so the call lasts roughly
    ``len(_module2range[module_type]) * recording_sleep_time_secs`` seconds.

    Every sysfs access is an ``echo``/``cat`` shell command located under the
    module-level ``sys_fs_prefix``; each command is waited for before the
    next one is issued.

    Args:
        set_perf_core: text put in front of every shell command (for example
            ``'taskset -c 1'``); may be empty.
        _module2intEvents_ids_dict: ``{module_type: {event_name:
            {'event_id': int, 'ctr_id': int}}}``.
        _module2range: ``{module_type: iterable of instance indices}``.
        module_type: key used in both mappings and prefix of the per-instance
            folder names (``f'{module_type}{module_idx}'``).
        recording_sleep_time_secs: time each instance stays enabled, in seconds.
        verbose: print a line when an instance starts and stops counting.

    Returns:
        ``(results_dict, zeros_dict)``.  Both are keyed by
        ``(f'{module_type}{module_idx}', event_name)`` and hold the raw text
        read from the counter file.  ``zeros_dict`` holds the values read
        just before the instance was enabled.  ``results_dict`` holds the
        values read after it was disabled, plus one ``'clock'`` entry per
        instance read from ``counter0``.
    """
    results_dict = {}
    zeros_dict = {}

    def write(module_idx, file_name, value):
        _run_shell(f"{set_perf_core} echo {value} > {_module_file(module_type, module_idx, file_name)}")

    def read(module_idx, file_name):
        return _run_shell(f"{set_perf_core} cat {_module_file(module_type, module_idx, file_name)}")

    # Stop every instance before any of them is reprogrammed.
    for module_idx in _module2range[module_type]:
        write(module_idx, 'enable', hex(0))

    for module_idx in _module2range[module_type]:
        module_name = f'{module_type}{module_idx}'

        # NOTE(review): count_clock=1 presumably makes counter0 count clock
        # cycles; counter0 is what is reported as 'clock' below -- confirm.
        write(module_idx, 'count_clock', 1)
        write(module_idx, 'counter0', hex(0))

        # Clear all counter/event slots so that no earlier configuration
        # survives in a slot this run does not program.
        for ctr_idx in range(1, 16):
            write(module_idx, f'counter{ctr_idx}', hex(0))
            write(module_idx, f'event{ctr_idx}', hex(0))

        module_events = _module2intEvents_ids_dict[module_type]

        # Program the requested events into their counter slots.
        for event_ids in module_events.values():
            ctr_id = event_ids['ctr_id']
            write(module_idx, f'counter{ctr_id}', hex(0))
            write(module_idx, f'event{ctr_id}', hex(event_ids['event_id']))

        # Baseline readings taken before counting starts.
        for event_name, event_ids in module_events.items():
            zeros_dict[(module_name, event_name)] = read(module_idx, f"counter{event_ids['ctr_id']}")

        if verbose: print((f"module {module_type}{module_idx}: STARTING .."))
        write(module_idx, 'enable', hex(1))
        time.sleep(recording_sleep_time_secs)
        write(module_idx, 'enable', hex(0))
        if verbose: print((f"module {module_type}{module_idx}: STOPPED"))

        for event_name, event_ids in module_events.items():
            results_dict[(module_name, event_name)] = read(module_idx, f"counter{event_ids['ctr_id']}")

        results_dict[(module_name, 'clock')] = read(module_idx, 'counter0')
        if verbose: print('---')
    return results_dict,zeros_dict


def main():
    """Command-line entry point: sample the counters and print bandwidth figures.

    Steps:
      1. Parse ``--sleep`` (sampling period in seconds, capped at 1) and
         ``--verbose``.
      2. Find the sub-folder of ``sys_fs_parent_folder`` whose ``name`` file
         contains ``bfperf`` and store it in the module-level
         ``sys_fs_prefix``.
      3. Compute the 't1' and 'n1' frequencies from the ``clock_measure`` files.
      4. Collect the 'llt', 'llt_miss' and 'mss' counters and print, as JSON,
         the raw values, the per-instance bandwidth and the bandwidth summed
         over instances.

    Raises:
        SystemExit: when ``--sleep`` is negative or no ``bfperf`` hwmon
            device can be found.
    """
    global sys_fs_prefix

    parser = argparse.ArgumentParser(description="sysfs access to performance counters in BF3",
                                     epilog = 'Text at the bottom of help')

    parser.add_argument('-s', '--sleep',   help="sample period in seconds", default=1, dest="sleep",type=float)
    parser.add_argument('-v', '--verbose', action='store_true',default=False)

    args = parser.parse_args()
    if args.sleep < 0:
        # time.sleep() rejects negative values; fail here with a usage message.
        parser.error('--sleep must not be negative')
    print(f'args: {args}')
    if args.sleep > 1:
        print(f'requested sleep time is more than 1 sec ({args.sleep}), setting it to 1 sec')
    recording_sleep_time_secs=args.sleep if args.sleep <=1 else 1
    verbose=args.verbose

    def read_text(path):
        with open(path) as handle:
            return handle.read()

    # Locate the counters device.  As before, the last matching folder (in
    # sorted order) wins when several match.
    try:
        hwmon_entries = sorted(os.listdir(sys_fs_parent_folder))
    except OSError:
        hwmon_entries = []
    bfperf_folder = None
    for subfolder in hwmon_entries:
        subfolder_path = os.path.join(sys_fs_parent_folder, subfolder)
        name_path = os.path.join(subfolder_path, 'name')
        if not os.path.isfile(name_path):
            continue
        try:
            device_name = read_text(name_path)
        except OSError:
            continue
        if 'bfperf' in device_name:
            bfperf_folder = subfolder_path
    if bfperf_folder is None:
        # Without this check the placeholder prefix would be used and the
        # first sysfs read would fail with an unrelated-looking error.
        raise SystemExit(f"no hwmon device whose name contains 'bfperf' found under {sys_fs_parent_folder}")
    sys_fs_prefix = bfperf_folder

    print(f'sysfs folder: {sys_fs_prefix}')

    def read_clock_measure(file_name):
        return int(read_text(os.path.join(sys_fs_prefix, "clock_measure", file_name)), 16)

    window_t = read_clock_measure("REFERENCE_WINDOW_WIDTH_PLL_T1")
    count_t = read_clock_measure("FMON_CLK_LAST_COUNT_PLL_T1")
    window_n = read_clock_measure("REFERENCE_WINDOW_WIDTH_PLL_N1")
    count_n = read_clock_measure("FMON_CLK_LAST_COUNT_PLL_N1")

    # NOTE(review): 156.25 is presumably the reference clock in MHz, which
    # would make these values MHz (they are used as 'freq_mhz') -- confirm.
    freqs_dict ={'t1': round((count_t/window_t)*156.25,2), 'n1': round((count_n/window_n)*156.25,2)}

    if verbose: print(f'freqs_dict: {freqs_dict}')

    pmc_llt_results_dict, _      = pmc_perf_ctrs_collect(set_perf_core, module2intEvents_ids_dict, module2range, 'llt', recording_sleep_time_secs,verbose)
    pmc_llt_miss_results_dict, _ = pmc_perf_ctrs_collect(set_perf_core, module2intEvents_ids_dict, module2range, 'llt_miss', recording_sleep_time_secs,verbose)
    pmc_mss_results_dict, _      = pmc_perf_ctrs_collect(set_perf_core, module2intEvents_ids_dict, module2range, 'mss', recording_sleep_time_secs,verbose)

    # Only the per-module view ({module: {counter: value}}) is used below.
    pmc_llt_mt_dict, _      = module_ctrs_to_hierarchical_dicts(pmc_llt_results_dict,     llt_int2ext_name_dict)
    pmc_llt_miss_mt_dict, _ = module_ctrs_to_hierarchical_dicts(pmc_llt_miss_results_dict,llt_miss_int2ext_name_dict)
    pmc_mss_mt_dict, _      = module_ctrs_to_hierarchical_dicts(pmc_mss_results_dict,     mss_int2ext_name_dict)

    pmc_res_dict = {'llt': {'freq_mhz': freqs_dict['t1'], 'data': pmc_llt_mt_dict},
                    'llt_miss': {'freq_mhz': freqs_dict['t1'], 'data': pmc_llt_miss_mt_dict},
                     'mss': {'freq_mhz': freqs_dict['n1'], 'data': pmc_mss_mt_dict}}

    print(" === pmc_res === ")
    print(json.dumps(pmc_res_dict,indent=4))

    print(" === pmc_res_bw ===")
    # Per instance: events per clock cycle * frequency * 64*8, scaled by
    # 1/1000.  Instances whose CLOCK counter is 0 are left out.
    # NOTE(review): 64*8 presumably means 64 bytes moved per counted event,
    # expressed in bits -- confirm.
    pmc_bw_dict = {}
    for module_type, module_res in pmc_res_dict.items():
        freq_mhz = module_res['freq_mhz']
        instance_bw = {}
        for instance_name, counters in module_res['data'].items():
            clock_count = counters['CLOCK']
            if 0 == clock_count:
                continue
            instance_bw[instance_name] = {
                f'{ctr_name}_Gbps': round(freq_mhz*(1/1000)*(64*8)*ctr_value/clock_count,2)
                for ctr_name, ctr_value in counters.items() if ctr_name != 'CLOCK'}
        pmc_bw_dict[module_type] = instance_bw

    print(json.dumps(pmc_bw_dict,indent=4))
    print(" === pmc_res_bw_agg ===")
    # Sum every counter over the instances of a module type.  A module type
    # with no usable instance yields an empty dict instead of an IndexError.
    pmc_res_bw_agg_dict = {}
    for module_type, instance_bw in pmc_bw_dict.items():
        per_instance = list(instance_bw.values())
        ctr_names = list(per_instance[0].keys()) if per_instance else []
        pmc_res_bw_agg_dict[module_type] = {
            ctr_name: round(sum([bw[ctr_name] for bw in per_instance]),2) for ctr_name in ctr_names}

    # Merge the two llt banks into one read and one write figure.
    llt_agg = pmc_res_bw_agg_dict['llt']
    if llt_agg:
        for direction in ('RD', 'WR'):
            bank0 = llt_agg.pop(f'BANK0_{direction}_Gbps')
            bank1 = llt_agg.pop(f'BANK1_{direction}_Gbps')
            llt_agg[f'{direction}_Gbps'] = round(bank0+bank1,2)
    print(json.dumps(pmc_res_bw_agg_dict,indent=4))

# Run the collection only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
