#!/usr/bin/env python3

#System
import time
import sys
import os
from enum import IntEnum

#BrainStem
import brainstem
from brainstem import _BS_C  #Gives access to aProtocolDef.h constants. 
from brainstem.result import Result

#Local files
from argument_parser import *

#Generic Utilities
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, "../utilities"))
sys.path.insert(0, parent_dir)
from brainstem_helpers import *
from poll_until import poll_until

#Power Delivery Utilities
pd_dir = os.path.abspath(os.path.join(script_dir, "../utilities/powerdelivery"))
sys.path.insert(0, pd_dir)
from bs_pd_packet_filtering import (
    validate_rdo_handshake, 
    clear_logger_packets,
)


#=============================================================================
# Constants
#=============================================================================
# Polling budget: how many poll intervals the wait_* helpers allow before timing out.
BAIL_COUNTER: int = 20
# Accepted deviation of VBUS from the nominal voltage (USB PD spec allows ±5% for fixed PDOs).
VOLTAGE_TOLERANCE_PERCENT: int = 5
# Extra attempts given to a failing test before it is counted as a failure.
MAX_TEST_RETRIES: int = 1

#=============================================================================
# Exit Codes
#=============================================================================
class ExitCode(IntEnum):
    """Process exit statuses reported by the RDO Test CLI script."""

    SUCCESS = 0                   # everything passed
    CONNECTION_ERROR = 1          # no stem could be connected
    CAPABILITY_CHECK_FAILED = 2   # device/port unsuitable for the test
    PDO_TEST_FAILED = 3           # at least one PDO mismatch
    RDO_TEST_FAILED = 4           # at least one RDO test failed
    PORT_DISABLE_FAILED = 5       # setEnabled(False) failed
    PORT_ENABLE_FAILED = 6        # setEnabled(True) failed
    UNKNOWN = 255                 # default before any outcome is known


# Human-readable text printed by exit_with_code() for each ExitCode.
EXIT_CODE_DESCRIPTIONS = {
    ExitCode.SUCCESS:
        "All tests passed",
    ExitCode.CONNECTION_ERROR:
        "Failed to connect to device",
    ExitCode.CAPABILITY_CHECK_FAILED:
        "Device does not support required capabilities",
    ExitCode.PDO_TEST_FAILED:
        "One or more PDO tests failed",
    ExitCode.RDO_TEST_FAILED:
        "One or more RDO tests failed",
    ExitCode.PORT_DISABLE_FAILED:
        "Failed to disable port",
    ExitCode.PORT_ENABLE_FAILED:
        "Failed to enable port",
    ExitCode.UNKNOWN:
        "Unknown error occurred",
}


class ProgramExit(Exception):
    """
    Raised to request program termination with a specific exit code.

    Raising (instead of calling sys.exit() on the spot) unwinds through any
    active ``with`` blocks, so context managers get to clean up first.

    Attributes:
        code: exit status the program should return.
        message: optional explanation; also the exception's single arg.
    """
    def __init__(self, code, message=None):
        super().__init__(message)
        self.code, self.message = code, message


def exit_with_code(code, additional_message=None):
    """
    Report ``code`` on stderr and raise ProgramExit to end the program.

    ProgramExit is raised instead of calling sys.exit() directly so that
    active context managers (e.g. CLI_Manager) can clean up before exit.

    Args:
        code: an ExitCode member or a plain integer status. An integer equal
            to an ExitCode value is treated as that member and reported with
            its description.
        additional_message: optional detail appended as " - <detail>" for
            both known and unknown codes.

    Raises:
        ProgramExit: always; carries the (normalized) code and the exact
            message that was printed.
    """
    try:
        code = ExitCode(code)
    except (ValueError, TypeError):
        # Not one of our exit codes: report it verbatim.
        message = "EXIT CODE %s: Unknown exit code" % (code,)
    else:
        description = EXIT_CODE_DESCRIPTIONS.get(code, "Unknown exit code")
        message = "EXIT CODE %d: %s" % (code.value, description)

    # Previously the detail was dropped for unknown codes; append it uniformly.
    if additional_message:
        message += " - %s" % additional_message
    print(message, file=sys.stderr)
    raise ProgramExit(code, message)


#=============================================================================
# Test Port Validation
#=============================================================================


def verify_test_port(stem, port_index):
    """
    Confirm that ``port_index`` can be used as the RDO test port.

    Two conditions are enforced, in order:
    - the port's data role equals _BS_C.portDataRole_Downstream_Value
    - the port index is not the value reported by getInputPowerSource()

    A failed read or a failed condition ends the program via
    exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, ...).
    """
    # --- Requirement 1: the port must be downstream ---
    role = brainstem.entity.Port(stem, port_index).getDataRole()
    if basic_error_handling(stem, role, "Could not get data role for port %d" % port_index):
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "getDataRole port %d" % port_index)

    if role.value != _BS_C.portDataRole_Downstream_Value:
        exit_with_code(
            ExitCode.CAPABILITY_CHECK_FAILED,
            "Port %d is not a downstream port (role=%d) and cannot be tested" % (port_index, role.value),
        )
    print("Port %d: Data role validated as downstream" % port_index)

    # --- Requirement 2: the port must not be feeding the device ---
    power_source = brainstem.entity.System(stem, 0).getInputPowerSource()
    if basic_error_handling(stem, power_source, "Could not get input power source"):
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "getInputPowerSource")

    if power_source.value == port_index:
        exit_with_code(
            ExitCode.CAPABILITY_CHECK_FAILED,
            "Port %d is the input power source and cannot be tested." % port_index,
        )
    print("Port %d: Verified not the input power source (power source=%d)" % (port_index, power_source.value))




#=============================================================================
# Test Functions
#=============================================================================
def test_voltage(stem, port_index, expected_voltage):
    """
    Compare the port's measured VBUS voltage with ``expected_voltage``.

    The acceptance window is expected_voltage ± VOLTAGE_TOLERANCE_PERCENT
    percent, with both bounds truncated to int.

    Returns:
        tuple: (passed: bool, actual_voltage: int, voltage_min: int, voltage_max: int).
        If the voltage read fails, ``passed`` is False and ``actual_voltage``
        is 0, but the real window bounds are still returned so callers can
        print a meaningful diagnostic (previously they were reported as 0).
    """
    # Compute the window first so it is available on the error path too.
    tolerance = expected_voltage * VOLTAGE_TOLERANCE_PERCENT / 100
    voltage_min = int(expected_voltage - tolerance)
    voltage_max = int(expected_voltage + tolerance)

    port_entity = brainstem.entity.Port(stem, port_index)
    voltage_result = port_entity.getVbusVoltage()
    if voltage_result.error:
        return (False, 0, voltage_min, voltage_max)

    passed = voltage_min <= voltage_result.value <= voltage_max
    return (passed, voltage_result.value, voltage_min, voltage_max)


def wait_for_voltage_to_establish(stem, port_index, expected_voltage):
    """
    Poll VBUS (via test_voltage) until it lands inside the tolerance window.

    Polls every 0.2 s with a timeout of BAIL_COUNTER * 0.2 s and prints the
    last measurement either way.

    Returns:
        int: 0 if the voltage settled in range, 1 if polling timed out.
    """
    def _sample():
        ok, measured, low, high = test_voltage(stem, port_index, expected_voltage)
        return (ok, (measured, low, high))

    poll = poll_until(_sample, timeout=BAIL_COUNTER * 0.2, interval=0.2)
    measured, low, high = poll.value or (0, 0, 0)
    report_args = (measured, expected_voltage, VOLTAGE_TOLERANCE_PERCENT, low, high)

    if not poll.timed_out:
        print("Voltage: %duV (expected: %duV ±%d%% = %duV to %duV)" % report_args)
        return 0

    print("Voltage out of range: %duV (expected: %duV ±%d%% = %duV to %duV)" % report_args)
    return 1


def is_pd_established(stem, port_index):
    """
    Report whether a non-zero RDO is present on the port.

    The local partner's RDO is checked first (set when the port is SINKING
    power), then the remote partner's (set when the port is SOURCING power).
    A read that fails basic_error_handling() is skipped.

    Returns:
        int: 0 if either RDO is non-zero, otherwise 1.
    """
    pd_entity = brainstem.entity.PowerDelivery(stem, port_index)

    for partner in (_BS_C.powerdeliveryPartnerLocal, _BS_C.powerdeliveryPartnerRemote):
        rdo = pd_entity.getRequestDataObject(partner)
        if basic_error_handling(stem, rdo):
            continue
        if rdo.value != 0:
            return 0

    return 1


def wait_for_pd_to_establish(stem, port_index):
    """
    Poll is_pd_established() every 0.2 s, for at most BAIL_COUNTER * 0.2 s.

    Returns:
        int: 0 once PD is established, 1 if polling timed out.
    """
    def _probe():
        status = is_pd_established(stem, port_index)
        return (status == 0, status)

    outcome = poll_until(_probe, timeout=BAIL_COUNTER * 0.2, interval=0.2)
    return 1 if outcome.timed_out else 0


def wait_for_rdo_to_set(stem, port_index, rdo):
    """
    Repeatedly call setRequestDataObject(rdo) until it returns a falsy status.

    Retries every 0.1 s for at most BAIL_COUNTER * 0.1 s. Only the return
    status of the set call is checked; the RDO is not read back here.

    Returns:
        int: 0 if a set call succeeded, 1 if polling timed out.
    """
    pd_entity = brainstem.entity.PowerDelivery(stem, port_index)

    def _attempt():
        status = pd_entity.setRequestDataObject(rdo)
        return (not status, status)

    outcome = poll_until(_attempt, timeout=BAIL_COUNTER * 0.1, interval=0.1)
    return 1 if outcome.timed_out else 0


def test_rdo(stem, logger, port_index, rdo, expected_voltage):
    """
    Exercise a single RDO on the port and check the resulting VBUS voltage.

    Steps, stopping at the first failure:
    1. set the port's power role to sink
    2. if ``logger`` is given, clear stale PD packets from it
    3. request ``rdo`` (wait_for_rdo_to_set)
    4. if ``logger`` is given, validate the RDO handshake in the logged traffic
    5. wait for a non-zero RDO to be reported (wait_for_pd_to_establish)
    6. wait for VBUS to settle near ``expected_voltage``

    Returns:
        int: 0 on success, 1 on the first failing step.
    """
    pd_entity = brainstem.entity.PowerDelivery(stem, port_index)

    role_status = pd_entity.setPowerRole(_BS_C.powerdeliveryPowerRoleSink)
    if basic_error_handling(stem, role_status, "Failed to set Power Role"):
        return 1

    # Drop stale PD packets so the handshake check only sees this transaction.
    if logger:
        clear_logger_packets(logger)

    if wait_for_rdo_to_set(stem, port_index, rdo):
        print("Failed to set RDO %08X" % rdo)
        return 1

    # With the PD Logging feature we can inspect the PD traffic and confirm
    # that this exact RDO value was sent and accepted.
    if logger:
        handshake = validate_rdo_handshake(logger, rdo)
        if handshake in (1, 2):
            if handshake == 1:
                print("RDO was rejected by the power supply")
            else:
                print("Timeout waiting for power supply ready")
            return 1

    if wait_for_pd_to_establish(stem, port_index):
        print("Failed to establish PD connection")
        return 1

    if wait_for_voltage_to_establish(stem, port_index, expected_voltage):
        print("Voltage is not as expected")
        return 1

    # TODO: Test loading of the RDO via External Load
    #    Enable Rail for "port_index"
    #    Check for expected current.
    #    Disable Rail for "port_index"

    return 0


def toggle_port(stem, port_index, settle_time=0.2):
    """
    Disable then re-enable a port to reset its connection.

    Sleeps ``settle_time`` seconds after each transition. A failed disable
    exits with PORT_DISABLE_FAILED; a failed enable exits with
    PORT_ENABLE_FAILED (both via exit_with_code).
    """
    port_entity = brainstem.entity.Port(stem, port_index)
    transitions = (
        (False, "Failed to disable port %d", ExitCode.PORT_DISABLE_FAILED),
        (True, "Failed to enable port %d", ExitCode.PORT_ENABLE_FAILED),
    )

    for enabled, failure_text, failure_code in transitions:
        status = port_entity.setEnabled(enabled)
        if basic_error_handling(stem, status, failure_text % port_index):
            exit_with_code(failure_code, "port %d" % port_index)
        time.sleep(settle_time)


def test_dut_pdo(stem, port_index, pdo_index, expected_pdo):
    """
    Read the remote partner's source PDO at ``pdo_index`` and compare it to ``expected_pdo``.

    Returns:
        int: 0 if the PDO matches, 1 if the read failed or the value differs.
    """
    pd_entity = brainstem.entity.PowerDelivery(stem, port_index)
    pdo = pd_entity.getPowerDataObject(
        _BS_C.powerdeliveryPartnerRemote,
        _BS_C.powerdeliveryPowerRoleSource,
        pdo_index,
    )
    if basic_error_handling(stem, pdo, "Failed to get PDO %d on port %d" % (pdo_index, port_index)):
        return 1

    actual = pdo.value
    print("PDO:%d: 0x%08X : Expected PDO: 0x%08X" % (pdo_index, actual, expected_pdo))
    if actual == expected_pdo:
        return 0

    print("FAIL - Unexpected Host PDO: 0x%08X : Expected PDO: 0x%08X" % (actual, expected_pdo))
    return 1


def run_with_retry(test_func, test_name, max_retries=MAX_TEST_RETRIES):
    """
    Call ``test_func`` once, retrying up to ``max_retries`` more times on failure.

    Args:
        test_func: zero-argument callable returning 0 on success, non-zero on failure.
        test_name: label printed with each retry.
        max_retries: retries allowed after the first failure (default: MAX_TEST_RETRIES).

    Returns:
        0 as soon as any attempt passes; otherwise the last attempt's result.
    """
    outcome = test_func()
    retry = 0
    while outcome != 0 and retry < max_retries:
        retry += 1
        print("    Retry %d/%d: %s" % (retry, max_retries, test_name))
        outcome = test_func()

    return 0 if outcome == 0 else outcome

#=============================================================================
# CLI Manager - Context manager for cleanup
#=============================================================================
class CLI_Manager:
    """
    Context manager that owns the stem connection and restores device state.

    On exit it restores the test port's saved power role, disables the PD
    logger, and disconnects the stem. Each step runs even if an earlier one
    raises; an exception raised during cleanup still propagates. Exceptions
    from the ``with`` body are never suppressed.
    """
    def __init__(self, test_port):
        self.stem = None  # brainstem.module.Module object
        self.test_port = test_port  # Port index being tested
        self.original_power_role = None  # Original power role to restore on exit
        self.logger = None  # PDChannelLogger object
        self.pd_logging_enabled = False  # Set by check_device_capability()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Restore the saved power role, disable the logger, and disconnect.

        The steps are chained with try/finally so a failure while restoring
        the power role or disabling the logger can no longer leave the stem
        connected.

        Returns:
            bool: False, so any exception from the ``with`` body propagates.
        """
        try:
            self.restore_port_power_role()
        finally:
            try:
                if self.logger:
                    self.logger.setEnabled(False)
            finally:
                self.logger = None
                try:
                    if self.stem:
                        self.stem.disconnect()
                finally:
                    self.stem = None

        return False  # Ensure exception propagates

    def _require_entity_index(self, cmd, cmd_name, range_message):
        """Exit unless the device has an entity at index self.test_port for command class ``cmd``."""
        result = self.stem.classQuantity(cmd)
        if basic_error_handling(self.stem, result, "Could not acquire class quantity for %s" % cmd_name):
            exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "%s class quantity" % cmd_name)

        if self.test_port >= result.value:
            exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, range_message % (self.test_port, result.value))

    def _require_uei(self, cmd, option, index, flags, error_message, exit_detail):
        """Exit unless stem.hasUEI(cmd, option, index, flags) passes basic_error_handling()."""
        result = self.stem.hasUEI(cmd, option, index, flags)
        if basic_error_handling(self.stem, result, error_message):
            exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, exit_detail)

    def check_device_capability(self):
        """
        Checks to see if the device can execute the required commands for the test.
        Uses hasUEI to verify the device supports the necessary operations.
        Sets self.pd_logging_enabled based on device capabilities.
        Uses self.test_port for the port index.

        Any missing required capability ends the program via
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, ...). PD logging is
        optional: its absence only prints a notice.
        """
        stem = self.stem
        port_index = self.test_port

        #---------------------------------------------------------------------------------
        #cmdPORT Checks
        #---------------------------------------------------------------------------------
        self._require_entity_index(_BS_C.cmdPORT, "cmdPORT",
                                   "Port %d is out of range. Device has %d ports.")

        # (option, flags, error message, exit detail) — each %d takes port_index.
        port_ueis = (
            (_BS_C.portPortEnabled, _BS_C.ueiOPTION_SET,
             "Cannot enable/disable port %d", "port %d setEnabled"),
            (_BS_C.portVbusVoltage, _BS_C.ueiOPTION_GET,
             "Cannot get vbus voltage on port %d", "port %d getVbusVoltage"),
            (_BS_C.portDataRole, _BS_C.ueiOPTION_GET,
             "Cannot get data role on port %d", "port %d getDataRole"),
        )
        for option, flags, error_fmt, detail_fmt in port_ueis:
            self._require_uei(_BS_C.cmdPORT, option, port_index, flags,
                              error_fmt % port_index, detail_fmt % port_index)

        #---------------------------------------------------------------------------------
        #cmdSYSTEM Checks
        #---------------------------------------------------------------------------------
        self._require_uei(_BS_C.cmdSYSTEM, _BS_C.systemInputPowerSource, 0, _BS_C.ueiOPTION_GET,
                          "Cannot get input power source", "getInputPowerSource")

        #---------------------------------------------------------------------------------
        #cmdPOWERDELIVERY Checks
        #---------------------------------------------------------------------------------
        self._require_entity_index(_BS_C.cmdPOWERDELIVERY, "cmdPOWERDELIVERY",
                                   "Port %d is out of range for PowerDelivery. Device has %d PD ports.")

        pd_ueis = (
            (_BS_C.powerdeliveryRequestDataObject, _BS_C.ueiOPTION_GET,
             "Cannot get RDOs on port %d", "port %d getRDO"),
            (_BS_C.powerdeliveryRequestDataObject, _BS_C.ueiOPTION_SET,
             "Cannot set RDOs on port %d", "port %d setRDO"),
            (_BS_C.powerdeliveryPowerDataObject, _BS_C.ueiOPTION_GET,
             "Cannot get PDOs on port %d", "port %d getPDO"),
            (_BS_C.powerdeliveryPowerRole, _BS_C.ueiOPTION_SET,
             "Cannot set power role on port %d", "port %d setPowerRole"),
            (_BS_C.powerdeliveryPowerRole, _BS_C.ueiOPTION_GET,
             "Cannot get power role on port %d", "port %d getPowerRole"),
        )
        for option, flags, error_fmt, detail_fmt in pd_ueis:
            self._require_uei(_BS_C.cmdPOWERDELIVERY, option, port_index, flags,
                              error_fmt % port_index, detail_fmt % port_index)

        #Check that we can enable PD Logging
        #In this example this is optional, but offers a more complete test.
        pd = brainstem.entity.PowerDelivery(stem, port_index)
        result = pd.get_UEI8(_BS_C.powerdeliveryLogEnable)
        if result.error == Result.NO_ERROR:
            self.pd_logging_enabled = True
        else:
            print("This device does not have the PD Logging software feature. That portion of this example will be bypassed.")
            self.pd_logging_enabled = False
        #---------------------------------------------------------------------------------

    def save_port_power_role(self):
        """
        Record the test port's current power role so __exit__ can restore it.

        Sets self.original_power_role to the value read from the device, or to
        None (after printing a warning) if the read fails.
        Uses self.test_port for the port index.
        """
        pd_entity = brainstem.entity.PowerDelivery(self.stem, self.test_port)
        result = pd_entity.getPowerRole()

        if basic_error_handling(self.stem, result, "Could not get current power role for port %d" % self.test_port):
            print("Warning: Will not be able to restore power role on exit")
            self.original_power_role = None
        else:
            self.original_power_role = result.value
            print("Port %d: Saved original power role: %d" % (self.test_port, self.original_power_role))

    def restore_port_power_role(self):
        """
        Restore the port's power role that was saved by save_port_power_role.
        Uses self.stem, self.test_port, and self.original_power_role.

        Does nothing if there is no stem, no test port, or no saved role.
        A failed restore prints a warning rather than raising.
        """
        if self.stem is None or self.test_port is None or self.original_power_role is None:
            return

        print("\n--- Restoring port power role ---")
        pd_entity = brainstem.entity.PowerDelivery(self.stem, self.test_port)
        err = pd_entity.setPowerRole(self.original_power_role)

        if basic_error_handling(self.stem, err, "Failed to restore power role"):
            print("Warning: Failed to restore power role to %d" % self.original_power_role)
        else:
            print("Restored power role to %d" % self.original_power_role)


#=============================================================================
# Main
#=============================================================================
def main(argv):
    """
    Run the PDO and RDO checks configured by the command-line arguments.

    Returns:
        The process exit status: an ExitCode, or the code carried by a
        ProgramExit. It is returned (not passed to sys.exit() here) so that
        CLI_Manager.__exit__ has already finished cleanup when the caller exits.
    """
    exit_code = ExitCode.UNKNOWN

    try:
        print("Provided Arguments:")
        print(argv)
        args = CustomArgumentParser(argv)

        with CLI_Manager(args.test_port) as cli:
            # ---- Setup ----
            # Note: this uses the generic Module base class (shared by USBHub3c,
            # USBHub3p, USBCSwitch, ...) rather than a specific device type, so
            # the device's capabilities must be verified explicitly.
            cli.stem = create_and_connect_stem(args.sn)
            if cli.stem is None:
                target = ("Serial: 0x%08X" % args.sn) if args.sn else "first found"
                exit_with_code(ExitCode.CONNECTION_ERROR, target)

            cli.check_device_capability()
            verify_test_port(cli.stem, cli.test_port)
            cli.save_port_power_role()

            if cli.pd_logging_enabled:
                # Logger object bound to the test port
                cli.logger = create_logger(cli.stem, cli.test_port)

            # ---- Work ----
            host_pdos = args.expected_host_pdos
            host_pdo_voltages = args.expected_host_pdos_voltage
            rdos = args.expected_rdos

            pdo_failures = 0
            rdo_failures = 0

            print("\n--- Testing PDO's ---")
            # PDO indices start from 1
            for pdo_number, pdo in enumerate(host_pdos, start=1):
                if test_dut_pdo(cli.stem, cli.test_port, pdo_number, pdo):
                    print("    Error: testing PDO: %d - 0x%08X" % (pdo_number, pdo))
                    pdo_failures += 1
                else:
                    print("    Success testing PDO: %d" % pdo_number)

            print("\n")

            toggle_port(cli.stem, cli.test_port)

            print("--- Testing RDO's ---")
            # RDO indices start from 1
            for rdo_number, rdo in enumerate(rdos, start=1):
                test_name = "RDO %d (0x%08X)" % (rdo_number, rdo)

                # Some power supplies can't keep up with rapid cycling through RDO's.
                # NOTE(review): assumes expected_host_pdos_voltage is aligned with
                # expected_rdos by position — confirm with argument_parser.
                outcome = run_with_retry(
                    lambda idx=rdo_number - 1: test_rdo(
                        cli.stem, cli.logger, cli.test_port, rdos[idx], host_pdo_voltages[idx]
                    ),
                    test_name,
                )
                if outcome:
                    print("    Error: testing %s" % test_name)
                    rdo_failures += 1
                else:
                    print("    Success testing RDO: %d" % rdo_number)

            # ---- Summary ----
            banner = "=" * 60
            print("\n" + banner)
            print("TEST RESULTS SUMMARY")
            print(banner)
            print("PDO Tests: %d/%d passed" % (len(host_pdos) - pdo_failures, len(host_pdos)))
            print("RDO Tests: %d/%d passed" % (len(rdos) - rdo_failures, len(rdos)))
            print(banner)

            if pdo_failures > 0:
                exit_code = ExitCode.PDO_TEST_FAILED
            elif rdo_failures > 0:
                exit_code = ExitCode.RDO_TEST_FAILED
            else:
                exit_code = ExitCode.SUCCESS

    except ProgramExit as stop:
        exit_code = stop.code

    # Returned only after the context manager's cleanup has completed.
    return exit_code


if __name__ == '__main__':
    # main() hands back the exit status after all cleanup has run.
    raise SystemExit(main(sys.argv))

