#System
import os
import sys
import time

# BrainStem
from brainstem import _BS_C  # Gives access to aProtocolDef.h constants.
from brainstem.pd_channel_logger import PDChannelLogger
from brainstem.result import Result
from brainstem.autoGen_PowerDelivery_Entity import PowerDelivery

#Local files
from pd_defs import *

#Generic Utilities
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, "../../utilities"))
sys.path.insert(0, parent_dir)
from utilities import *


def get_request_dict():
    """
    Build the lookup table used for PD request filtering and decoding.

    Returns:
        dict keyed by the _BS_C PD request constant. Every value is a
        three-item list laid out as [tx_filter, rx_filter, decoder].
    """
    table = dict()

    table[_BS_C.pdRequestManufacturerInfoSop] = [
        filter_tx_manufacturer_info,
        filter_rx_manufacturer_info,
        decode_pdRequestManufacturerInfoSop,
    ]
    table[_BS_C.pdRequestBatteryCapabilities] = [
        filter_tx_battery_cap,
        filter_rx_battery_cap,
        decode_pdRequestBatteryCapabilities,
    ]
    table[_BS_C.pdRequestBatteryStatus] = [
        filter_tx_battery_status,
        filter_rx_battery_status,
        decode_pdRequestBatteryStatus,
    ]

    return table


def wait_for_silence(logger, silence_time_seconds=2, max_time_to_wait_seconds=10):
    """
    Block until PD traffic has been silent for a defined amount of time,
    or until the overall timeout expires.

    At the beginning of PD conversations some devices can be really chatty (Apple).
    During this time many devices will ignore/drop PD packets.
    Note: This function is destructive; every packet read while waiting is discarded.

    Args:
        logger: Object exposing getPacket(), which returns a Result.
        silence_time_seconds: Accumulated quiet time required for success.
        max_time_to_wait_seconds: Upper bound on the total time spent waiting.

    Returns:
        True once the line has been quiet for silence_time_seconds,
        False if max_time_to_wait_seconds elapsed first.
    """
    start_time = time.time()  # Record the start time
    last_time = start_time
    silence_time = 0

    result = logger.getPacket()
    while True:
        current_time = time.time()
        elapsed_time = current_time - start_time

        if result.error == Result.NOT_READY:
            # Nothing pending: credit the time since the previous poll as silence.
            silence_time += current_time - last_time
            if silence_time >= silence_time_seconds:
                return True

            time.sleep(.1)
        elif result.error == Result.NO_ERROR:
            # A packet arrived, so the quiet period starts over.
            silence_time = 0
        else:
            # Format the error code into the message (it was previously passed
            # to print() as a separate argument, leaving a literal "%d").
            print("Unknown Error: %d" % result.error)
            # Back off so a persistent error does not busy-spin and flood stdout.
            time.sleep(.1)

        if elapsed_time >= max_time_to_wait_seconds:
            print("Maximum time reached(silence)")
            break

        result = logger.getPacket()
        last_time = current_time

    return False


def clear_logger_packets(logger):
    """
    Clear all pending packets from the logger.

    Keeps calling logger.getPackets() until it reports Result.NOT_READY.
    Any other error also ends the drain; looping on it would never terminate
    if the error is persistent.

    Args:
        logger: Object exposing getPackets(), which returns a Result.
    """
    while True:
        result = logger.getPackets()
        if result.error == Result.NOT_READY:
            # Queue is empty - nothing left to discard.
            break

        if result.error != Result.NO_ERROR:
            # Unexpected failure: report it and stop instead of spinning forever.
            print("Unknown Error: %d" % result.error)
            break


def wait_for_pd_packet(logger, sop, filter_func, max_time_seconds=5):
    """
    Poll the logger until a packet satisfies filter_func or time runs out.

    This is the core building block for PD packet filtering.
    Note: This function is destructive in that it consumes every packet it
    inspects, whether or not the packet matches.

    Args:
        logger: PDChannelLogger instance
        sop: SOP type handed through to filter_func
        filter_func: Callable(packet, sop) -> bool. Can be a simple function,
                     SequenceFilter, or any callable.
        max_time_seconds: Maximum time to wait for a matching packet

    Returns:
        The Result holding the packet that made filter_func return True,
        or None when max_time_seconds elapsed without a match.
    """
    began = time.time()

    while True:
        response = logger.getPacket()
        waited = time.time() - began

        if response.error != Result.NO_ERROR:
            # No usable packet this round; pause before polling again.
            time.sleep(0.1)
        elif filter_func(response.value, sop):
            return response

        # The timeout is evaluated after the packet in hand has been examined,
        # so a match on the final poll is still reported.
        if waited >= max_time_seconds:
            print("Maximum time reached (wait_for_pd_packet)")
            return None


def pd_packet_filter(logger, sop, func_tx, func_rx, max_time_seconds=5):
    """
    Wait for a transmit packet followed by a receive packet.

    func_tx must match first and func_rx afterwards; only then is the
    exchange considered successful. This ensures the desired packet was
    sent AND a response was received.

    Args:
        logger: PDChannelLogger instance
        sop: SOP type handed to both filter functions
        func_tx: Callable(packet, sop) -> bool for the outgoing packet
        func_rx: Callable(packet, sop) -> bool for the incoming packet
        max_time_seconds: Maximum time to wait for the full exchange

    Returns:
        Whatever wait_for_pd_packet returns for the combined filter.
    """
    tx_then_rx = SequenceFilter(filters=[func_tx, func_rx])
    return wait_for_pd_packet(
        logger,
        sop,
        tx_then_rx,
        max_time_seconds=max_time_seconds,
    )


#////////////////////////////////////////////////////
#Generic Stateful Filter Classes
#////////////////////////////////////////////////////
class SequenceFilter:
    """
    A stateful filter that matches a sequence of filters in order.

    Each filter in the sequence must match before the next is considered.
    The filter returns True when ALL filters in the sequence have matched,
    OR when an abort condition is triggered (allowing early bail on known failures).

    Example:
        # Wait for Accept then PS_RDY, but bail immediately on Reject
        seq_filter = SequenceFilter(
            [filter_rx_rdo_accept, filter_rx_rdo_ps_rdy],
            abort_on=[filter_rx_rdo_reject]
        )
        result = wait_for_pd_packet(logger, sop, seq_filter)

        if seq_filter.aborted:
            print("Request was rejected")
        elif seq_filter.completed:
            print("Success!")

    Attributes:
        filters: Private copy of the ordered filter callables
        abort_filters: Private copy of the abort callables
        index: Position of the next filter that still has to match
        matched_packets: List of packets matched at each stage
        completed: True if all filters have matched successfully
        aborted: True if an abort condition was triggered
        abort_packet: The packet that triggered the abort, if any
    """
    def __init__(self, filters, abort_on=None):
        """
        Args:
            filters: Iterable of Callable(packet, sop) -> bool, in match order.
            abort_on: Optional iterable of Callable(packet, sop) -> bool. A
                      match on any of them stops the sequence immediately.
        """
        self.filters = list(filters)
        # Copy the abort filters just like `filters`: a one-shot iterable
        # (e.g. a generator) would otherwise be exhausted after the first
        # packet, and later changes to the caller's list would leak in.
        self.abort_filters = list(abort_on or [])
        self.index = 0
        self.matched_packets = []
        self.aborted = False
        self.abort_packet = None

    def __call__(self, packet, sop):
        """
        Feed one packet through the filter.

        Returns:
            True when the caller should stop waiting (sequence just completed
            or an abort filter matched), otherwise False.
        """
        # Check abort conditions first
        for abort_filter in self.abort_filters:
            if abort_filter(packet, sop):
                self.aborted = True
                self.abort_packet = packet
                return True  # Signal wait_for_pd_packet to stop

        # Progress through sequence
        if self.index < len(self.filters):
            if self.filters[self.index](packet, sop):
                self.matched_packets.append(packet)
                self.index += 1
                if self.index >= len(self.filters):
                    return True  # Sequence complete
        return False

    @property
    def completed(self):
        """ Returns True if all filters in the sequence have matched successfully. """
        return self.index >= len(self.filters) and not self.aborted

    def reset(self):
        """ Reset the filter state to start matching from the beginning. """
        self.index = 0
        self.matched_packets = []
        self.aborted = False
        self.abort_packet = None
#////////////////////////////////////////////////////


#////////////////////////////////////////////////////
#pdRequestBatteryCapabilities(14) - Battery Capabilities PD Message filters
#////////////////////////////////////////////////////
def filter_tx_battery_cap(packet, sop):
    """
    Match an outgoing Get_Battery_Cap extended message.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # The 2-byte message header must be present and the SOP must agree.
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 1 = TX - Case for us sending the pd packet
    # kPD_MESSAGE_EXTENDED_TYPE_GET_BATTERY_CAP = 3
    return bool(
        packet.direction == 1
        and header.messageType == 3
        and header.extended
    )


def filter_rx_battery_cap(packet, sop):
    """
    Match an incoming Battery_Capabilities extended message.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # The 2-byte message header must be present and the SOP must agree.
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 2 = RX - Case for receiving the pd packet
    # kPD_MESSAGE_EXTENDED_TYPE_BATTERY_CAPABILITIES = 5
    return bool(
        packet.direction == 2
        and header.messageType == 5
        and header.extended
    )


def decode_pdRequestBatteryCapabilities(result):
    """
    Decode a Battery Capabilities response into a dictionary.

    Offsets read from the payload:
        4-5   Vendor ID
        6-7   Product ID
        8-9   Design capacity, in tenths of a WH
        10-11 Last full charge capacity, in tenths of a WH
        12    Battery type bit field (bit 0 = invalid, bits 1-7 = reserved)
    Bytes 0-3 are skipped (presumably message header + extended header).

    A raw capacity of 0xFFFF is reported as None, matching how
    decode_pdRequestBatteryStatus treats the same sentinel. Without this the
    sentinel would show up as a bogus 6553.5 WH reading.
    NOTE(review): 0xFFFF is taken to mean "capacity unknown" - confirm
    against the USB PD specification.

    Args:
        result: Object whose `payload` holds the raw message bytes.

    Returns:
        Result(Result.NO_ERROR, dict) on success, or
        Result(Result.SIZE_ERROR, {}) when the payload is shorter than 13 bytes.
    """
    if len(result.payload) < 13:
        return Result(Result.SIZE_ERROR, dict())

    vendor_id = get_two_bytes_from_buffer(result.payload, 4)
    product_id = get_two_bytes_from_buffer(result.payload, 6)
    design_cap = get_two_bytes_from_buffer(result.payload, 8)
    last_full_cap = get_two_bytes_from_buffer(result.payload, 10)
    battery_type = result.payload[12]

    d = dict()
    d["Vendor ID"] = ("0x%04X" % (vendor_id))
    d["Product ID"] = ("0x%04X" % (product_id))
    # Raw values are in 0.1 WH steps; 0xFFFF is a sentinel, not a capacity.
    d["Design Cap (WH)"] = None if design_cap == 0xFFFF else (design_cap * 100) / 1000
    d["Last Full Cap (WH)"] = None if last_full_cap == 0xFFFF else (last_full_cap * 100) / 1000
    d["Invalid"] = battery_type & 0x01
    d["Reserved"] = (battery_type & 0xFE) >> 1
    return Result(Result.NO_ERROR, d)

#////////////////////////////////////////////////////




#////////////////////////////////////////////////////
#pdRequestBatteryStatus(15) - Battery Status Info PD Message filters
#////////////////////////////////////////////////////
def filter_tx_battery_status(packet, sop):
    """
    Match an outgoing Get_Battery_Status extended message.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # The 2-byte message header must be present and the SOP must agree.
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 1 = TX - Case for us sending the pd packet
    # kPD_MESSAGE_EXTENDED_TYPE_GET_BATTERY_STATUS = 4
    return bool(
        packet.direction == 1
        and header.messageType == 4
        and header.extended
    )


def filter_rx_battery_status(packet, sop):
    """
    Match an incoming Battery_Status data message.

    The payload has to be exactly 6 bytes long and the message must be a
    non-extended one carrying a single data object.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # Only an exact 6-byte payload on the requested SOP qualifies.
    if len(packet.payload) != 6 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 2 = RX - Case for receiving the pd packet
    # kPD_MESSAGE_DATA_TYPE_BATTERY_STATUS = 5
    return bool(
        packet.direction == 2
        and header.messageType == 5
        and header.extended == False
        and header.numOfDataObjs == 1
    )


def decode_pdRequestBatteryStatus(result):
    """
    Decode a Battery Status response into a dictionary.

    The 4 bytes starting at payload offset 2 are converted with
    rule_to_battery_status_data_object and its fields are copied out.

    Args:
        result: Object whose `payload` holds the raw message bytes.

    Returns:
        Result(Result.NO_ERROR, dict) when the payload is exactly 6 bytes,
        otherwise Result(Result.SIZE_ERROR, {}).
    """
    if len(result.payload) != 6:
        return Result(Result.SIZE_ERROR, dict())

    raw_bsdo = get_four_bytes_from_buffer(result.payload, 2)
    bsdo = rule_to_battery_status_data_object(raw_bsdo)

    # A raw value of 0xFFFF is reported as None rather than as a charge level.
    raw_charge = bsdo.battery_pc
    if raw_charge == 0xFFFF:
        state_of_charge = None
    else:
        state_of_charge = (raw_charge * 100) / 1000

    info_bits = bsdo.battery_info_bit_fields

    # Note on "Charging Status":
    # typedef enum {
    #     kPDBatteryChargeCharging     = 0,
    #     kPDBatteryChargeDischarging  = 1,
    #     kPDBatteryChargeIdle         = 2,
    #     kPDBatteryChargeReserved     = 3,
    # } PD_AppBatteryState_t;
    decoded = {
        "State of Charge (WH)": state_of_charge,
        "Reserved 1": bsdo.reserved,
        "Invalid Battery Reference": info_bits.invalid_battery_ref,
        "Battery Present": info_bits.battery_is_present,
        "Charging Status": info_bits.battery_charging_status,
        "Reserved 2": info_bits.reserved,
    }

    return Result(Result.NO_ERROR, decoded)
#////////////////////////////////////////////////////




#////////////////////////////////////////////////////
#pdRequestManufacturerInfoSop(16) - Manufacturer Info PD Message filters
#////////////////////////////////////////////////////
def filter_tx_manufacturer_info(packet, sop):
    """
    Match an outgoing Get_Manufacturer_Info extended message.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # The 2-byte message header must be present and the SOP must agree.
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 1 = TX - Case for us sending the pd packet
    # kPD_MESSAGE_EXTENDED_TYPE_GET_MANUFACTURER_INFO = 6
    return bool(
        packet.direction == 1
        and header.messageType == 6
        and header.extended
    )


def filter_rx_manufacturer_info(packet, sop):
    """
    Match an incoming Manufacturer_Info extended message.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # The 2-byte message header must be present and the SOP must agree.
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 2 = RX - Case for receiving the pd packet
    # kPD_MESSAGE_EXTENDED_TYPE_MANUFACTURER_INFO = 7
    return bool(
        packet.direction == 2
        and header.messageType == 7
        and header.extended
    )


def decode_pdRequestManufacturerInfoSop(result):
    """
    Decode a Manufacturer Info response into a dictionary.

    Payload layout: 2 bytes of header, 2 bytes of extended header,
    2 bytes Vendor ID, 2 bytes Product ID, then the manufacturer string.
    Every byte from offset 8 onwards is turned into one character.

    Args:
        result: Object whose `payload` holds the raw message bytes.

    Returns:
        Result(Result.NO_ERROR, dict) on success, or
        Result(Result.SIZE_ERROR, {}) when the payload is shorter than 8 bytes.
    """
    payload = result.payload
    if len(payload) < 8:
        return Result(Result.SIZE_ERROR, dict())

    decoded = {
        "Vendor ID": "0x%04X" % get_two_bytes_from_buffer(payload, 4),
        "Product ID": "0x%04X" % get_two_bytes_from_buffer(payload, 6),
        "Mfg String": ''.join(map(chr, payload[8:])),
    }
    return Result(Result.NO_ERROR, decoded)
#////////////////////////////////////////////////////




#////////////////////////////////////////////////////
#RDO Request and Response Filters
#////////////////////////////////////////////////////

def filter_tx_rdo_with_value(packet, sop, expected_rdo):
    """
    Match an outgoing Request message whose RDO equals expected_rdo.

    Note: This function takes one argument more than a normal filter function.
    To use it as a filter, bind expected_rdo first, for example:
        filter_tx_rdo = lambda pkt, sop: filter_tx_rdo_with_value(pkt, sop, expected_rdo)

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.
        expected_rdo: 32-bit RDO value that must have been sent.

    Returns:
        True when a Request carrying exactly expected_rdo was sent, else False.
        Every outgoing single-object Request is printed along with the
        comparison outcome, whether or not it matches.
    """
    # Header (2 bytes) plus one data object (4 bytes) are required.
    if len(packet.payload) < 6 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 1 = TX - Case for us sending the pd packet
    # messageType = kPD_MESSAGE_DATA_TYPE_REQUEST = 2
    is_request = (
        packet.direction == 1
        and header.messageType == 2
        and header.numOfDataObjs == 1
        and header.extended == False
    )
    if not is_request:
        return False

    # The RDO occupies payload bytes 2-5, directly after the header.
    rdo_value = get_four_bytes_from_buffer(packet.payload, 2)
    is_match = rdo_value == expected_rdo
    print("RDO 0x%08X sent. Expected RDO 0x%08X - Match: %s" % (rdo_value, expected_rdo, is_match))
    return is_match


def filter_rx_rdo_accept(packet, sop):
    """
    Match an incoming Accept control message.

    Prints "Accept received" when the packet matches.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # The 2-byte message header must be present and the SOP must agree.
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 2 = RX - Case for receiving the pd packet
    # Accept is a control message (numOfDataObjs = 0)
    # messageType = kPD_MESSAGE_CONTROL_TYPE_ACCEPT = 3
    is_accept = bool(
        packet.direction == 2
        and header.messageType == 3
        and header.numOfDataObjs == 0
        and header.extended == False
    )
    if is_accept:
        print("Accept received")
    return is_accept


def filter_rx_rdo_reject(packet, sop):
    """
    Match an incoming Reject control message.

    Prints "Reject received" when the packet matches.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # The 2-byte message header must be present and the SOP must agree.
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 2 = RX - Case for receiving the pd packet
    # Reject is a control message (numOfDataObjs = 0)
    # messageType = kPD_MESSAGE_CONTROL_TYPE_REJECT = 4
    is_reject = bool(
        packet.direction == 2
        and header.messageType == 4
        and header.numOfDataObjs == 0
        and header.extended == False
    )
    if is_reject:
        print("Reject received")
    return is_reject


def filter_rx_rdo_ps_rdy(packet, sop):
    """
    Match an incoming PS_RDY (Power Supply Ready) control message.

    Prints "PS_RDY received" when the packet matches.

    Args:
        packet: Logged PD packet; reads payload, sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet matches, otherwise False.
    """
    # The 2-byte message header must be present and the SOP must agree.
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))

    # Direction 2 = RX - Case for receiving the pd packet
    # PS_RDY is a control message (numOfDataObjs = 0)
    # messageType = kPD_MESSAGE_CONTROL_TYPE_PS_RDY = 6
    is_ps_rdy = bool(
        packet.direction == 2
        and header.messageType == 6
        and header.numOfDataObjs == 0
        and header.extended == False
    )
    if is_ps_rdy:
        print("PS_RDY received")
    return is_ps_rdy
#////////////////////////////////////////////////////



#////////////////////////////////////////////////////
#Note: Filtering is minimal here because DFU commands travel
#over SOP'Debug and SOP''Debug, which are not commonly used.
#////////////////////////////////////////////////////
def filter_tx_apple_dfu_vdm(packet, sop):
    """
    Match any outgoing packet on the given SOP.

    A True result acts as the latch indicating that our packet was seen
    being sent. The payload is not inspected.

    Args:
        packet: Logged PD packet; reads sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet is a TX packet on `sop`, otherwise False.
    """
    # Direction 1 = TX - Case for us sending the pd packet
    return bool(packet.sop == sop and packet.direction == 1)


def filter_rx_apple_dfu_vdm(packet, sop):
    """
    Match any incoming packet on the given SOP.

    The payload is not inspected.

    Args:
        packet: Logged PD packet; reads sop and direction.
        sop: SOP type the packet must carry.

    Returns:
        True when the packet is an RX packet on `sop`, otherwise False.
    """
    # Direction 2 = RX - Case for receiving the pd packet
    return bool(packet.sop == sop and packet.direction == 2)
#////////////////////////////////////////////////////



#////////////////////////////////////////////////////
# PD Sequence Validators.
#////////////////////////////////////////////////////

def validate_rdo_handshake(logger, expected_rdo, sop=0, max_time_seconds=5):
    """
    Wait for a successful RDO transaction (Request -> Accept -> PS_RDY) or detect rejection.

    Args:
        logger: PDChannelLogger instance
        expected_rdo: RDO value the outgoing Request must carry
        sop: SOP type to watch
        max_time_seconds: Maximum time to wait for the handshake

    Returns:
        0 RDO Accepted and Ready
        1 RDO Rejected
        2 Timeout
        3 Unknown error (wait ended without completion or rejection)
    """
    def sent_expected_rdo(pkt, pkt_sop):
        # TX stage: only the specific RDO value we asked for counts.
        return filter_tx_rdo_with_value(pkt, pkt_sop, expected_rdo)

    # Wait for RDO, Accept + PS_RDY sequence, bail immediately on Reject
    handshake = SequenceFilter(
        [sent_expected_rdo, filter_rx_rdo_accept, filter_rx_rdo_ps_rdy],
        abort_on=[filter_rx_rdo_reject],
    )

    outcome = wait_for_pd_packet(logger, sop, handshake, max_time_seconds)

    if outcome is None:
        return 2  # Timeout

    if handshake.aborted:
        return 1  # Rejected (bailed early)

    # 0 = Success, 3 = Unknown error
    return 0 if handshake.completed else 3
#////////////////////////////////////////////////////