#!/usr/bin/env python3

#System
import time
import sys
import os
from enum import IntEnum

#BrainStem
import brainstem
from brainstem import _BS_C #Gives access to aProtocolDef.h constants. 
from brainstem.result import Result

#Local files
from argument_parser import *

#Generic Utilities
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, "../utilities"))
sys.path.insert(0, parent_dir)
from brainstem_helpers import *
from poll_until import poll_until


#=============================================================================
# Device Speed
#=============================================================================
class DeviceSpeed(IntEnum):
    """
    USB device speed, numbered to match the DeviceNode speed values.

    Members are declared slowest to fastest (after UNKNOWN), so ordering
    comparisons such as ``expected_speed > device_max`` are meaningful.
    """
    UNKNOWN = 0             # speed could not be determined
    LOW_SPEED = 1           # USB 1.x low speed, 1.5 Mbit/s
    FULL_SPEED = 2          # USB 1.x full speed, 12 Mbit/s
    HIGH_SPEED = 3          # USB 2.0 high speed, 480 Mbit/s
    SUPER_SPEED = 4         # USB 3.x, 5 Gbit/s
    SUPER_SPEED_PLUS = 5    # USB 3.x, 10 Gbit/s


#=============================================================================
# Constants
#=============================================================================
# Speed test matrix, fastest first. Each row is:
#   (hub-wide HS max datarate, hub-wide SS max datarate, expected DeviceSpeed, column label)
# In every row one of the two limits is the corresponding "_None" value.
DATA_RATE_CONFIGURATIONS = [
    # SuperSpeed rows: HS limit _None, SS limit capped.
    (_BS_C.usbsystemDataHSMaxDatarate_None,
     _BS_C.usbsystemDataSSMaxDatarate_SuperSpeedPlus,
     DeviceSpeed.SUPER_SPEED_PLUS,
     "10G"),
    (_BS_C.usbsystemDataHSMaxDatarate_None,
     _BS_C.usbsystemDataSSMaxDatarate_SuperSpeed,
     DeviceSpeed.SUPER_SPEED,
     "5G"),
    # USB 2.0 rows: SS limit _None, HS limit capped.
    (_BS_C.usbsystemDataHSMaxDatarate_HighSpeed,
     _BS_C.usbsystemDataSSMaxDatarate_None,
     DeviceSpeed.HIGH_SPEED,
     "480M"),
    (_BS_C.usbsystemDataHSMaxDatarate_FullSpeed,
     _BS_C.usbsystemDataSSMaxDatarate_None,
     DeviceSpeed.FULL_SPEED,
     "12M"),
    # No Acroname product currently supports this configuration; kept for completeness.
    (_BS_C.usbsystemDataHSMaxDatarate_LowSpeed,
     _BS_C.usbsystemDataSSMaxDatarate_None,
     DeviceSpeed.LOW_SPEED,
     "1.5M"),
]

# Result codes from Port.getDataRole() that are treated as "past the last valid
# port" while probing port indices (this includes PARAMETER/UNIMPLEMENTED, not
# only the range errors).
ALLOWED_RANGE_ERRORS = [
    Result.INDEX_RANGE_ERROR,
    Result.RANGE_ERROR,
    Result.UNIMPLEMENTED_ERROR,
    Result.PARAMETER_ERROR,
]

# Result codes from the HS/SS max-datarate setters that are treated as
# "configuration not supported by this device" rather than a hard failure.
UNSUPPORTED_CONFIG_ERRORS = [
    Result.PARAMETER_ERROR,
    Result.RANGE_ERROR,
    Result.UNIMPLEMENTED_ERROR,
]

# Upper bound for port-index probing; larger than any Acroname port count.
MAX_ACRONAME_PORTS = 30
# Minimum VBUS on the control port, in microvolts, for it to count as connected.
CONTROL_PORT_MIN_MICRO_VOLTS = 4000000


#=============================================================================
# Exit Codes
#=============================================================================
class ExitCode(IntEnum):
    """Exit codes for the Speed Test script."""
    SUCCESS = 0
    CONNECTION_ERROR = 1
    CAPABILITY_CHECK_FAILED = 2
    CONFIGURATION_ERROR = 3
    TEST_FAILED = 4
    UNKNOWN = 255


# Human-readable description printed alongside each ExitCode.
EXIT_CODE_DESCRIPTIONS = {
    ExitCode.SUCCESS: "All speed tests passed",
    ExitCode.CONNECTION_ERROR: "Failed to connect to device",
    ExitCode.CAPABILITY_CHECK_FAILED: "Device does not support required capabilities",
    ExitCode.CONFIGURATION_ERROR: "Failed to configure hub data rate settings",
    ExitCode.TEST_FAILED: "One or more speed tests failed",
    ExitCode.UNKNOWN: "Unknown error occurred",
}


class ProgramExit(Exception):
    """
    Custom exception to signal program exit with a specific code.

    This exception is used instead of sys.exit() to allow context managers
    to properly clean up before the program exits.

    Attributes:
        code: The exit code to return from main().
        message: The text describing the exit (may be None).
    """
    def __init__(self, code, message=None):
        self.code = code
        self.message = message
        super().__init__(message)


def exit_with_code(code, additional_message=None):
    """
    Print an exit-code message to stderr and raise ProgramExit.

    Raises ProgramExit instead of calling sys.exit() directly to ensure
    context managers can properly clean up.

    Args:
        code: An ExitCode, or a plain int. Ints matching an ExitCode value are
              converted to that ExitCode so they get their proper description.
        additional_message: Optional detail appended as " - <detail>".

    Raises:
        ProgramExit: always. ``code`` is the (normalized) exit code and
                     ``message`` is exactly the text printed to stderr.
    """
    # Normalize plain ints (e.g. 1) to their ExitCode so they are not
    # reported as "Unknown exit code".
    try:
        code = ExitCode(code)
    except ValueError:
        pass

    if isinstance(code, ExitCode):
        message = "EXIT CODE %d: %s" % (code.value, EXIT_CODE_DESCRIPTIONS[code])
    else:
        message = "EXIT CODE %d: Unknown exit code" % code

    # Previously the detail was dropped for unknown codes; include it for both.
    if additional_message:
        message += " - %s" % additional_message

    print(message, file=sys.stderr)
    raise ProgramExit(code, message)


#=============================================================================
# Configuration Result
#=============================================================================
class ConfigResult(IntEnum):
    """Outcome of trying to apply a hub max-data-rate configuration."""
    SUCCESS = 0      # configuration applied
    UNSUPPORTED = 1  # the device rejected this configuration as unsupported
    ERROR = 2        # any other failure



def check_device_capability(stem, test_ports=None):
    """
    Checks to see if the device can execute the required commands for the test.

    Uses classQuantity/hasUEI to verify the device supports the necessary
    operations. Exits via exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, ...)
    on the first missing capability.

    Checks performed:
      * At least one USBSystem entity exists and USBSystem 0 supports SET of
        the HS and SS max data rate options.
      * The Port class quantity can be read.
      * For each port in test_ports (if given): the index is within
        [0, port count), SET of portPortEnabled and GET of portDataSpeed are
        supported.

    Args:
        stem: The BrainStem module connection
        test_ports: Iterable of ports to validate. If None, only the USBSystem
                    checks and the Port class quantity read are performed.
    """
    #---------------------------------------------------------------------------------
    #cmdUSBSYSTEM Checks
    #---------------------------------------------------------------------------------
    #Check how many USBSystem Entities this device has.
    result = stem.classQuantity(_BS_C.cmdUSBSYSTEM)
    if basic_error_handling(stem, result, "Could not acquire class quantity for cmdUSBSYSTEM"):
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "cmdUSBSYSTEM class quantity")

    if result.value < 1:
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "Device does not have USBSystem capabilities")

    #Check that we can set HS max data rate
    result = stem.hasUEI(_BS_C.cmdUSBSYSTEM, _BS_C.usbsystemDataHSMaxDatarate, 0, (_BS_C.ueiOPTION_SET))
    if basic_error_handling(stem, result, "Device does not support setting HS max data rate"):
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "setDataHSMaxDatarate")

    #Check that we can set SS max data rate
    result = stem.hasUEI(_BS_C.cmdUSBSYSTEM, _BS_C.usbsystemDataSSMaxDatarate, 0, (_BS_C.ueiOPTION_SET))
    if basic_error_handling(stem, result, "Device does not support setting SS max data rate"):
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "setDataSSMaxDatarate")
    #---------------------------------------------------------------------------------

    #---------------------------------------------------------------------------------
    #cmdPORT Checks
    #---------------------------------------------------------------------------------
    #Check how many Port Entities this device has.
    result = stem.classQuantity(_BS_C.cmdPORT)
    if basic_error_handling(stem, result, "Could not acquire class quantity for cmdPORT"):
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "cmdPORT class quantity")
    port_count = result.value

    # If no test_ports provided, skip per-port validation
    if test_ports is None:
        return

    for port in test_ports:
        # Lower bound too: a negative index would otherwise pass the check.
        if not 0 <= port < port_count:
            exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED,
                           "Port %d is out of range. Device has %d ports." % (port, port_count))

        #Check that we can enable/disable the port
        result_uei = stem.hasUEI(_BS_C.cmdPORT, _BS_C.portPortEnabled, port, (_BS_C.ueiOPTION_SET))
        if basic_error_handling(stem, result_uei, "Cannot enable/disable port %d" % port):
            exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "port %d setEnabled" % port)

        #Check that we can get data speed
        result_uei = stem.hasUEI(_BS_C.cmdPORT, _BS_C.portDataSpeed, port, (_BS_C.ueiOPTION_GET))
        if basic_error_handling(stem, result_uei, "Cannot get data speed on port %d" % port):
            exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "port %d getDataSpeed" % port)
    #---------------------------------------------------------------------------------


def verify_control_port_connected(stem):
    """
    Make sure the hub's control port exists and is powered.

    Probes port indices from 0 upward (at most MAX_ACRONAME_PORTS). It stops
    at the first index whose getDataRole() error is in ALLOWED_RANGE_ERRORS.
    The first port whose data role is portDataRole_Control_Value is taken as
    the control port, and its VBUS voltage is read.

    Exits with ExitCode.CAPABILITY_CHECK_FAILED if:
      * no control port is found,
      * the voltage cannot be read, or
      * the voltage is below CONTROL_PORT_MIN_MICRO_VOLTS (microvolts).
    """
    print("\n--- Verifying Control Port Connection ---")

    control_port_idx = None
    for candidate in range(MAX_ACRONAME_PORTS):
        role = brainstem.entity.Port(stem, candidate).getDataRole()

        if role.error in ALLOWED_RANGE_ERRORS:
            break  # walked past the last port on this device
        if basic_error_handling(stem, role, "Error getting data role for port %d" % candidate):
            continue
        if role.value != _BS_C.portDataRole_Control_Value:
            continue

        control_port_idx = candidate
        print("Found control port: Port %d" % candidate)
        break

    if control_port_idx is None:
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "No control port found on device")

    vbus = brainstem.entity.Port(stem, control_port_idx).getVbusVoltage()
    if basic_error_handling(stem, vbus, "Error getting voltage for control port %d" % control_port_idx):
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "Cannot read control port voltage")

    if vbus.value < CONTROL_PORT_MIN_MICRO_VOLTS:
        exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED,
                       "The control port does not appear to be connected. Please connect it and try again.")
    

def validate_ports_not_upstream(stem, ports):
    """
    Refuse to test any port whose data role is upstream.

    A port whose data role cannot be read is reported through
    basic_error_handling and skipped (not treated as a failure).

    Args:
        stem: The BrainStem module connection
        ports: Iterable of port indices to validate

    Exits with ExitCode.CAPABILITY_CHECK_FAILED on the first upstream port.
    """
    for port_index in ports:
        role_result = brainstem.entity.Port(stem, port_index).getDataRole()

        read_failed = basic_error_handling(
            stem, role_result, "Warning: Could not get data role for port %d" % port_index)
        if not read_failed:
            if role_result.value == _BS_C.portDataRole_Upstream_Value:
                exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED,
                               "Port %d is an upstream port and cannot be tested" % port_index)
            print("Port %d: Data role validated (role=%d)" % (port_index, role_result.value))


def discover_downstream_ports(stem):
    """
    Find every downstream port on the hub.

    Probes port indices from 0 up to MAX_ACRONAME_PORTS - 1 and reads each
    port's data role. Probing stops early when:
      * getDataRole() returns an error in ALLOWED_RANGE_ERRORS, which is
        taken to mean the index is past the last port, or
      * any exception is raised while probing (the exception is printed).
    Ports whose role cannot be read for other reasons are skipped.

    Returns:
        List of port indices whose data role is portDataRole_Downstream_Value,
        in ascending order.
    """
    downstream = []
    for idx in range(MAX_ACRONAME_PORTS):
        try:
            role = brainstem.entity.Port(stem, idx).getDataRole()

            if role.error in ALLOWED_RANGE_ERRORS:
                print("Port discovery complete: found %d total ports" % idx)
                break

            if basic_error_handling(stem, role, "Warning: Error getting data role for port %d" % idx):
                continue

            if role.value != _BS_C.portDataRole_Downstream_Value:
                print("Port %d: Not downstream (role=%d)" % (idx, role.value))
                continue

            downstream.append(idx)
            print("Port %d: Downstream" % idx)
        except Exception as exc:
            # Best effort: any unexpected failure ends discovery here.
            print("Port discovery stopped at port %d: %s" % (idx, str(exc)))
            break

    print("Discovered %d downstream ports: %s" % (len(downstream), downstream))
    return downstream


def filter_ports_with_devices(stem, ports):
    """
    Keep only the ports that currently have an enumerated device.

    A port counts as populated when Port.getDataSpeed() succeeds with a value
    greater than zero. Its decoded DeviceSpeed is recorded as the maximum
    speed that device reached under the current hub configuration.

    Args:
        stem: The BrainStem module connection
        ports: Iterable of port indices to inspect

    Returns:
        tuple: (list of populated port indices in input order,
                dict mapping port -> DeviceSpeed)
    """
    populated = []
    max_speed_by_port = {}

    for port in ports:
        speed_result = brainstem.entity.Port(stem, port).getDataSpeed()

        if basic_error_handling(stem, speed_result, "Port %d: Error getting data speed" % port):
            continue

        if speed_result.value <= 0:
            print("Port %d: No device enumerated" % port)
            continue

        decoded = decode_data_speed(speed_result.value)
        print("Port %d: Device connected at %s" % (port, decoded.name))
        populated.append(port)
        max_speed_by_port[port] = decoded

    print("Found %d ports with devices: %s" % (len(populated), populated))
    return populated, max_speed_by_port


def decode_data_speed(speed_value):
    """
    Translate a raw Port.getDataSpeed() bitmask into a DeviceSpeed.

    The USB 3.0 "connected" bit is examined before the USB 2.0 one, because
    USB hubs enumerate at both USB 2.0 and USB 3.0 speeds. Once the 3.0 bit is
    seen, only the SuperSpeed bits are consulted. Each matched step is printed.

    Args:
        speed_value: Raw bitmask returned by Port.getDataSpeed().

    Returns:
        The matching DeviceSpeed, or DeviceSpeed.UNKNOWN when no recognised
        connection/speed bit combination is present.
    """
    def has_bit(bit):
        return bool(speed_value & (1 << bit))

    if has_bit(_BS_C.portDataSpeed_Connected_3p0_Bit):
        print("Device enumerated as USB 3.0")
        speed_table = (
            (_BS_C.portDataSpeed_ss_5G_Bit, "Data speed is 5G", DeviceSpeed.SUPER_SPEED),
            (_BS_C.portDataSpeed_ss_10G_Bit, "Data speed is 10G", DeviceSpeed.SUPER_SPEED_PLUS),
        )
    elif has_bit(_BS_C.portDataSpeed_Connected_2p0_Bit):
        print("Device enumerated as USB 2.0")
        speed_table = (
            (_BS_C.portDataSpeed_ls_1p5M_Bit, "Data speed is 1.5M", DeviceSpeed.LOW_SPEED),
            (_BS_C.portDataSpeed_fs_12M_Bit, "Data speed is 12M", DeviceSpeed.FULL_SPEED),
            (_BS_C.portDataSpeed_hs_480M_Bit, "Data speed is 480M", DeviceSpeed.HIGH_SPEED),
        )
    else:
        speed_table = ()

    # First matching bit in table order wins.
    for bit, message, speed in speed_table:
        if has_bit(bit):
            print(message)
            return speed

    print("Device enumerated as unknown")
    return DeviceSpeed.UNKNOWN


def compare_device_speeds(hub_speed, os_speed, expected_speed):
    """
    Decide whether a single port/configuration test passed.

    The test passes only when all three speeds agree:
      1. The OS must report a known speed (not DeviceSpeed.UNKNOWN).
      2. The hub's speed must equal the OS's speed.
      3. That speed must equal the expected speed.
    The reason for the verdict is printed.

    Returns:
        True on a pass, False otherwise.
    """
    if os_speed == DeviceSpeed.UNKNOWN:
        print("Cannot compare: OS speed not available.")
        return False

    failure = None
    if hub_speed != os_speed:
        failure = "FAIL: Hub/OS mismatch - Hub: %s, OS: %s" % (hub_speed.name, os_speed.name)
    elif hub_speed != expected_speed:
        failure = "FAIL: Speed mismatch - Got: %s, Expected: %s" % (hub_speed.name, expected_speed.name)

    if failure is not None:
        print(failure)
        return False

    print("PASS: Device speed matches expected: %s" % (expected_speed.name))
    return True


def get_device_speed_from_os(port, serial_number=0, timeout=8.0):
    """
    Get the device speed as reported by the OS for a specific hub port.

    Polls brainstem.discover.getDownstreamDevices() every 0.3 s until a
    device on `port` is found, or until `timeout` seconds elapse. When
    `serial_number` is given, the device must also be on the hub with that
    serial number.

    Args:
        port: Hub port index, matched against each device's hub_port.
        serial_number: Hub serial number to match. Any falsy value (0 or None)
                       matches devices on any hub.
        timeout: Maximum time to wait, in seconds.

    Returns:
        The DeviceSpeed reported by the OS. Returns DeviceSpeed.UNKNOWN on
        timeout or when the reported speed is not a DeviceSpeed value.
    """
    def check_devices():
        # Note: getDownstreamDevices() is a static function and does not require a stem connection;
        # however, it only returns devices which are physically connected to Acroname hubs.
        devices = brainstem.discover.getDownstreamDevices()

        # A failed discovery call is retried on the next poll instead of
        # crashing while iterating an unusable value.
        if devices.error or devices.value is None:
            return False, DeviceSpeed.UNKNOWN

        for device in devices.value:
            if device.hub_port != port:
                continue
            # Treat None like 0 ("any hub"); callers may pass an unset serial
            # as None, which previously filtered out every device.
            if serial_number and device.hub_serial_number != serial_number:
                continue
            try:
                return True, DeviceSpeed(device.speed)
            except ValueError:
                # Reported speed is outside the DeviceSpeed enumeration.
                return True, DeviceSpeed.UNKNOWN
        return False, DeviceSpeed.UNKNOWN

    result = poll_until(check_devices, timeout=timeout, interval=0.3)

    if result.timed_out:
        print("Timeout waiting for OS to enumerate device")
        return DeviceSpeed.UNKNOWN

    return result.value


def _reenable_ports_best_effort(stem, ports):
    """Try to enable every port in `ports`; individual failures are ignored."""
    for port in ports:
        brainstem.entity.Port(stem, port).setEnabled(True)


def prepare_hub_max_datarate_configuration(stem, ports, hs_speed, ss_speed, settle_time=0.5):
    """
    Prepare the hub for a speed test by configuring data rate limits.

    Disables all specified ports, configures the HS/SS max data rates, then re-enables all ports.
    Since max data rate settings are hub-wide, all ports must be cycled together.
    On every failure path the ports are re-enabled (best effort) so that a
    failed attempt does not leave ports switched off.

    Args:
        stem: The BrainStem module connection.
        ports: Iterable of port indices to cycle.
        hs_speed: Value passed to USBSystem.setDataHSMaxDatarate().
        ss_speed: Value passed to USBSystem.setDataSSMaxDatarate().
        settle_time: Seconds to sleep after a successful re-enable, to let
                     devices enumerate.

    Returns:
        ConfigResult.SUCCESS: Configuration applied successfully
        ConfigResult.UNSUPPORTED: Device does not support this data rate configuration
        ConfigResult.ERROR: Other error occurred
    """
    usb_system = brainstem.entity.USBSystem(stem, 0)

    # Disable all ports. On failure, undo any ports already disabled.
    for port in ports:
        err = brainstem.entity.Port(stem, port).setEnabled(False)
        if basic_error_handling(stem, err, "Failed to disable port %d" % port):
            _reenable_ports_best_effort(stem, ports)
            return ConfigResult.ERROR

    # Configure the hub-wide max data rates: HS first, then SS.
    rate_steps = (
        ("HS", usb_system.setDataHSMaxDatarate, hs_speed),
        ("SS", usb_system.setDataSSMaxDatarate, ss_speed),
    )
    for bus, setter, speed in rate_steps:
        err = setter(speed)
        if err in UNSUPPORTED_CONFIG_ERRORS:
            print("Configuration not supported by device (%s): %s" % (bus, Result.getErrorText(err)))
            _reenable_ports_best_effort(stem, ports)
            return ConfigResult.UNSUPPORTED
        if basic_error_handling(stem, err, "Failed to set %s max data rate" % bus):
            _reenable_ports_best_effort(stem, ports)
            return ConfigResult.ERROR

    # Re-enable every port, even if an earlier one fails, then report the failure.
    enable_failed = False
    for port in ports:
        err = brainstem.entity.Port(stem, port).setEnabled(True)
        if basic_error_handling(stem, err, "Failed to enable port %d" % port):
            enable_failed = True
    if enable_failed:
        return ConfigResult.ERROR

    time.sleep(settle_time)  # Allow time for devices to enumerate.

    return ConfigResult.SUCCESS


def get_device_speed_from_hub(stem, port, timeout=5.0, max_retries=5):
    """
    Get the device speed as reported by the hub for a specific port.

    Polls Port.getDataSpeed() every 0.3 s until a device enumerates (non-zero
    speed), `max_retries` read errors have occurred, or `timeout` seconds
    elapse.

    Args:
        stem: The BrainStem module connection.
        port: Hub port index.
        timeout: Maximum time to wait, in seconds.
        max_retries: Number of failed reads after which polling stops.

    Returns:
        The decoded DeviceSpeed. Returns DeviceSpeed.UNKNOWN when nothing
        enumerated or the last read was an error.
    """
    retry_count = 0  # Will be used as a "nonlocal" variable.
    port_entity = brainstem.entity.Port(stem, port)

    def check_speed():
        nonlocal retry_count  # Uses scope variable to keep state through multiple calls.
        speed_result = port_entity.getDataSpeed()

        if speed_result.error:
            print("Error getting data speed: %s - retry %d" % (speed_result.error, retry_count))
            retry_count += 1
            if retry_count >= max_retries:
                return True, speed_result  # Stop polling due to max retries
            return False, speed_result
        elif speed_result.value > 0:
            print("Device enumeration detected")
            return True, speed_result
        return False, speed_result

    result = poll_until(check_speed, timeout=timeout, interval=0.3)

    if result.timed_out:
        print("Timeout waiting for device to enumerate")

    last_read = result.value
    # Only decode a successful read. An error Result's value is not a speed
    # bitmask (and may not even be an int), so treat it as "no speed".
    if last_read is None or last_read.error:
        speed_value = 0
    else:
        speed_value = last_read.value
    return decode_data_speed(speed_value)


def print_results_table(results, test_ports, speed_labels):
    """
    Print a summary table of test results and report whether all passed.

    Args:
        results: Mapping of (port, speed_label) -> outcome, where outcome is
            True  - Test passed
            False - Test failed
            None  - Configuration not supported by device (N/A)
          Keys that are missing (or hold any other value) are shown as "--"
          and do not count as failures.
        test_ports: Ports to print, one row each, in the given order.
        speed_labels: Column labels, one column each, in the given order.

    Returns:
        True if no cell is a failure (False), otherwise False.
    """
    # Calculate column widths
    port_col_width = 6  # "Port X"
    # default=0 keeps an empty label list from raising ValueError in max().
    speed_col_width = max((len(label) for label in speed_labels), default=0) + 2

    header = "Port".ljust(port_col_width) + " | "
    header += " | ".join(label.center(speed_col_width) for label in speed_labels)
    # Rules are at least 60 wide but grow to cover wide tables.
    rule = "=" * max(60, len(header))

    # Print header
    print("\n" + rule)
    print("TEST RESULTS SUMMARY")
    print(rule)

    # Print column headers
    print(header)
    print("-" * len(header))

    # Print each row (port)
    all_passed = True
    for port in test_ports:
        cells = []
        for label in speed_labels:
            outcome = results.get((port, label), "--")
            # Identity checks on purpose: only real True/False/None are classified.
            if outcome is True:
                text = "PASS"
            elif outcome is False:
                text = "FAIL"
                all_passed = False
            elif outcome is None:
                # Configuration not supported by device - not a failure
                text = "N/A"
            else:
                text = "--"
            cells.append(text.center(speed_col_width))
        print(("%s" % port).ljust(port_col_width) + " | " + " | ".join(cells))

    print(rule)

    # Print overall result
    if all_passed:
        print("OVERALL: ALL TESTS PASSED")
    else:
        print("OVERALL: SOME TESTS FAILED")
    print(rule)

    return all_passed


#Context manager that handles cleanup of the stem.
class CLI_Manager:
    """
    Context manager that owns the BrainStem connection used by main().

    On exit, if both a stem and test_ports are set, it restores the hub-wide
    max data rates (HS = HighSpeed, SS = SuperSpeedPlus) by cycling
    test_ports. It then disconnects the stem. The disconnect happens even if
    the restore raises. Exceptions are never suppressed.
    """
    def __init__(self):
        self.stem = None  # brainstem.module.Module object
        self.test_ports = None  # Ports cycled when restoring defaults on exit
        self.device_max_speeds = {}  # Maps port -> max DeviceSpeed

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            if self.stem and self.test_ports:
                # Restore hub to default configuration before disconnecting
                print("\n--- Restoring hub to default configuration ---")
                restore_result = prepare_hub_max_datarate_configuration(
                    self.stem,
                    self.test_ports,
                    _BS_C.usbsystemDataHSMaxDatarate_HighSpeed,
                    _BS_C.usbsystemDataSSMaxDatarate_SuperSpeedPlus
                )
                if restore_result == ConfigResult.ERROR:
                    print("Warning: failed to restore hub default configuration")
        finally:
            # Always release the connection, even if the restore raised.
            if self.stem:
                print("\n--- Disconnecting from device ---")
                self.stem.disconnect()
                self.stem = None

        return False  # Ensure exception propagates


def main(argv):
    """
    Run the hub data-rate speed test.

    Steps:
      1. Connect to the BrainStem device.
      2. Check device capabilities.
      3. Verify the control port is connected.
      4. Choose candidate ports: auto-discovered downstream ports, or the
         ports given on the command line.
      5. Reset the hub to its default max data rates.
      6. Keep only the ports with a device attached.
      7. For each entry in DATA_RATE_CONFIGURATIONS, compare the hub-reported
         speed with the OS-reported speed.

    Args:
        argv: Command line arguments (typically sys.argv), passed to
              CustomArgumentParser.

    Returns:
        The exit code for sys.exit(): an ExitCode, or the code carried by a
        ProgramExit.
    """
    exit_code = ExitCode.UNKNOWN

    try:
        print("Provided Arguments:")
        print(argv)
        arg_parser = CustomArgumentParser(argv)

        with CLI_Manager() as cli:
            #Setup
            #/////////////////////////////////////////////////////////////////////
            #---------------------------------------------------------------------
            # Note: This code uses the base Module class instead of a specific device type.
            # The Module class is the base class for all BrainStem Objects like the USBHub3c, USBHub3p, USBCSwitch etc.
            # This allows our code to be more generic; however, we don't really know what we are or what
            # we are capable of so we must do a handful of capability checks.

            # 1. Connect to device
            cli.stem = create_and_connect_stem(arg_parser.sn)
            if cli.stem is None:
                exit_with_code(ExitCode.CONNECTION_ERROR, "Serial: 0x%08X" % arg_parser.sn if arg_parser.sn else "first found")

            # 2. Check basic device capabilities before calling any other APIs
            check_device_capability(cli.stem)

            # 3. Verify control port is connected before proceeding
            verify_control_port_connected(cli.stem)

            # 4. Determine candidate ports based on mode
            if arg_parser.automatic:
                print("\n--- Automatic Port Discovery ---")
                candidate_ports = discover_downstream_ports(cli.stem)
                if not candidate_ports:
                    exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "No downstream ports found on device")
            else:
                # Manual mode - range/capability-check the user's ports before
                # any port is power-cycled, then make sure none is upstream.
                check_device_capability(cli.stem, arg_parser.test_ports)
                validate_ports_not_upstream(cli.stem, arg_parser.test_ports)
                candidate_ports = arg_parser.test_ports

            # From here on ports are power-cycled; register them so CLI_Manager
            # restores and re-enables them if anything below exits early.
            cli.test_ports = candidate_ports

            # 5. Ensure ports are enabled and the max data rates are set.
            print("\n--- Resetting hub to default configuration ---")
            reset_result = prepare_hub_max_datarate_configuration(
                cli.stem,
                candidate_ports,
                _BS_C.usbsystemDataHSMaxDatarate_HighSpeed,
                _BS_C.usbsystemDataSSMaxDatarate_SuperSpeedPlus,
                settle_time=3.0 # Some device might need more or less time to enumerate.
            )
            # UNSUPPORTED is tolerated: the device rejected the default rate
            # but the ports were re-enabled. A hard ERROR must not be ignored.
            if reset_result == ConfigResult.ERROR:
                exit_with_code(ExitCode.CONFIGURATION_ERROR, "Failed to reset hub to default data rates")

            # 6. Filter to ports with connected devices (common to both modes)
            ports_with_devices, device_max_speeds = filter_ports_with_devices(cli.stem, candidate_ports)
            if not ports_with_devices:
                exit_with_code(ExitCode.CAPABILITY_CHECK_FAILED, "No devices connected to downstream ports")

            cli.test_ports = ports_with_devices
            cli.device_max_speeds = device_max_speeds

            # 7. Validate capabilities for the specific test ports
            check_device_capability(cli.stem, cli.test_ports)

            print("\n--- Setup complete: testing ports %s ---" % cli.test_ports)
            #---------------------------------------------------------------------
            #/////////////////////////////////////////////////////////////////////

            #Work
            #/////////////////////////////////////////////////////////////////////
            # Track test results: (port, speed_label) -> True/False/None
            # None indicates the configuration is unsupported by the device
            results = {}
            speed_labels = [config[3] for config in DATA_RATE_CONFIGURATIONS]

            # Loop through the data rate configurations
            for hs_speed, ss_speed, expected_speed, speed_label in DATA_RATE_CONFIGURATIONS:
                config_result = prepare_hub_max_datarate_configuration(cli.stem, cli.test_ports, hs_speed, ss_speed)

                if config_result == ConfigResult.ERROR:
                    exit_with_code(ExitCode.CONFIGURATION_ERROR, "Failed to prepare hub for %s" % speed_label)

                if config_result == ConfigResult.UNSUPPORTED:
                    # Mark all ports as unsupported for this configuration
                    print("--- Skipping %s tests (configuration not supported by BrainStem device) ---" % speed_label)
                    for port in cli.test_ports:
                        results[(port, speed_label)] = None  # None = unsupported/N/A
                    continue

                # Test each port with the current configuration
                for port in cli.test_ports:
                    # Check if the device is capable of the expected speed
                    if port in cli.device_max_speeds:
                        device_max = cli.device_max_speeds[port]
                        if expected_speed > device_max:
                            print("\n--- Skipping Port %d @ %s (device max speed is %s) ---" % (port, speed_label, device_max.name))
                            results[(port, speed_label)] = None  # None = N/A
                            continue

                    print("\n--- Testing Port %d @ %s (expecting %s) ---" % (port, speed_label, expected_speed.name))

                    hub_speed = get_device_speed_from_hub(cli.stem, port)
                    os_speed = get_device_speed_from_os(port, arg_parser.sn)

                    success = compare_device_speeds(hub_speed, os_speed, expected_speed)
                    results[(port, speed_label)] = success
            #/////////////////////////////////////////////////////////////////////

            # Print summary table
            all_passed = print_results_table(results, cli.test_ports, speed_labels)

            if all_passed:
                exit_code = ExitCode.SUCCESS
            else:
                exit_code = ExitCode.TEST_FAILED

    except ProgramExit as e:
        exit_code = e.code

    # Exit at the end, after context manager cleanup is complete
    return exit_code


if __name__ == "__main__":
    # Equivalent to sys.exit(): raise SystemExit with main()'s exit code.
    raise SystemExit(main(sys.argv))

