#System
import os
import sys
import time

#BrainStem
from brainstem import _BS_C #Gives access to aProtocolDef.h constants. 
from brainstem.pd_channel_logger import PDChannelLogger
from brainstem.result import Result
from brainstem.autoGen_PowerDelivery_Entity import PowerDelivery

#Local files
from pd_defs import *

#Generic Utilities
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(script_dir, "../../utilities"))
sys.path.insert(0, parent_dir)
from utilities import *


#This function should be kept up to date with the implemented request filters within this file. 
def get_request_dict():
    """Build the lookup table of PD requests implemented in this file.

    Returns:
        dict: a new dict on every call, keyed by a ``_BS_C.pdRequest*``
        constant. Each value is a 3-item list
        ``[tx_filter, rx_filter, decoder]``, where both filters are called
        as ``filter(packet, sop)``.
    """
    # (request constant, [tx filter, rx filter, decoder]) pairs, in the same
    # order as they should appear in the resulting dict.
    entries = (
        (_BS_C.pdRequestManufacturerInfoSop,
         [filter_tx_manufacturer_info, filter_rx_manufacturer_info, decode_pdRequestManufacturerInfoSop]),
        (_BS_C.pdRequestBatteryCapabilities,
         [filter_tx_battery_cap, filter_rx_battery_cap, decode_pdRequestBatteryCapabilities]),
        (_BS_C.pdRequestBatteryStatus,
         [filter_tx_battery_status, filter_rx_battery_status, decode_pdRequestBatteryStatus]),
    )
    return dict(entries)


#This function returns when it has timed out or when PD traffic has been
#silent for a defined amount of time.
#At the beginning of PD conversations some devices (e.g. Apple) can be very chatty,
#and during this time many devices will ignore/drop PD packets.
def wait_for_silence(logger, silence_time_seconds=2, max_time_to_wait_seconds=10):
    """Poll ``logger`` until PD traffic has been quiet for a while.

    Every packet read while waiting is consumed (via ``logger.getPacket()``)
    and discarded.

    Args:
        logger: packet source whose ``getPacket()`` returns a Result-like
            object with an ``.error`` field. ``Result.NOT_READY`` is treated
            as "no traffic", ``Result.NO_ERROR`` as "a packet was seen".
        silence_time_seconds: accumulated quiet time required before
            returning True. Any received packet resets the accumulator.
        max_time_to_wait_seconds: overall deadline in seconds.

    Returns:
        True once the required silence has accumulated, or False if the
        deadline is reached first.
    """
    start_time = time.time()  # Record the start time
    last_time = start_time
    silence_time = 0

    result = logger.getPacket()
    while True:
        current_time = time.time()
        elapsed_time = current_time - start_time

        if result.error == Result.NOT_READY:
            # Nothing pending: extend the quiet period by the time since the last poll.
            silence_time += current_time - last_time
            if silence_time >= silence_time_seconds:
                return True

            time.sleep(.1)
        elif result.error == Result.NO_ERROR:
            # Traffic seen: restart the quiet-period measurement.
            silence_time = 0
        else:
            # print() does not apply %-formatting to extra arguments, so format
            # explicitly (previously the literal "%d" was printed).
            print("Unknown Error: %s" % (result.error,))
            # Back off rather than spinning (and flooding stdout) on a
            # persistent error until the deadline expires.
            time.sleep(.1)

        if elapsed_time >= max_time_to_wait_seconds:
            print("Maximum time reached(silence)")
            break

        result = logger.getPacket()
        last_time = current_time

    return False



#Generic PD packet filter function which requires both
#a transmit condition and a receive condition for success.
#This function ensures the desired packet is sent AND the expected response is received.
#Note: In this implementation all other packets will be discarded.
def pd_packet_filter(logger, sop, func_tx, func_rx, max_time_seconds=5):
    """Wait for a matching TX packet followed by a matching RX packet.

    Packets are pulled from ``logger.getPacket()``. Until ``func_tx`` accepts
    a packet only ``func_tx`` is consulted; after that only ``func_rx`` is.
    Any packet that does not satisfy the currently active predicate is
    discarded.

    Args:
        logger: packet source; each ``getPacket()`` result exposes
            ``.error`` and ``.value``.
        sop: SOP type forwarded unchanged to both predicates.
        func_tx: ``func_tx(packet, sop)`` -> truthy for the packet we sent.
        func_rx: ``func_rx(packet, sop)`` -> truthy for the expected reply.
        max_time_seconds: overall deadline in seconds.

    Returns:
        The ``getPacket()`` result holding the matching RX packet, or None
        if the deadline is reached first.
    """
    begin = time.time()
    request_seen = False  # latches once func_tx has matched
    polled = logger.getPacket()
    while True:
        waited = time.time() - begin

        if polled.error != Result.NO_ERROR:
            # Nothing usable from this poll; back off briefly.
            time.sleep(.1)
        elif request_seen:
            if func_rx(polled.value, sop):
                return polled
        elif func_tx(polled.value, sop):
            # Our request went out; from now on only look for the reply.
            request_seen = True

        if waited >= max_time_seconds:
            print("Maximum time reached (pd_packet_filter)")
            return None

        polled = logger.getPacket()




#////////////////////////////////////////////////////
#pdRequestBatteryCapabilities(14) - Battery Capabilities PD Message filters
#////////////////////////////////////////////////////
def filter_tx_battery_cap(packet, sop):
    """Return True if *packet* is our outgoing Get_Battery_Cap request on *sop*.

    A match requires at least a 2-byte header, the given SOP, direction 1
    (TX), and an extended message of type 3
    (kPD_MESSAGE_EXTENDED_TYPE_GET_BATTERY_CAP).
    """
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))
    is_get_battery_cap = (
        packet.direction == 1           # 1 = TX (sent by us)
        and header.messageType == 3     # GET_BATTERY_CAP
        and header.extended
    )
    return bool(is_get_battery_cap)


def filter_rx_battery_cap(packet, sop):
    """Return True if *packet* is a received Battery_Capabilities reply on *sop*.

    A match requires at least a 2-byte header, the given SOP, direction 2
    (RX), and an extended message of type 5
    (kPD_MESSAGE_EXTENDED_TYPE_BATTERY_CAPABILITIES).
    """
    long_enough = len(packet.payload) >= 2
    if not (long_enough and packet.sop == sop):
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))
    if packet.direction != 2:  # 2 = RX (received from the partner)
        return False
    if header.messageType != 5:  # BATTERY_CAPABILITIES
        return False
    return bool(header.extended)


def decode_pdRequestBatteryCapabilities(result):
    """Decode a Battery_Capabilities extended message into a dict.

    Byte offsets read from ``result.payload``:
        0-1    message header (not decoded here)
        2-3    extended message header (not decoded here)
        4-5    Vendor ID
        6-7    Product ID
        8-9    design capacity, in units of 0.1 WH
        10-11  last full charge capacity, in units of 0.1 WH
        12     battery type byte: bit 0 = invalid battery reference,
               bits 7:1 reserved

    A capacity field of 0xFFFF means "unknown" in the USB PD Battery
    Capabilities Data Block. It is reported as None instead of 6553.5 WH,
    consistent with how decode_pdRequestBatteryStatus treats 0xFFFF.

    Args:
        result: packet-like object exposing ``payload``.

    Returns:
        ``Result(Result.NO_ERROR, dict)`` on success, or
        ``Result(Result.SIZE_ERROR, {})`` if the payload is shorter than the
        13 bytes needed to reach the battery type byte.
    """
    if len(result.payload) < 13:
        return Result(Result.SIZE_ERROR, dict())

    def to_wh(raw):
        # Raw capacity is in 0.1 WH steps; 0xFFFF is the "unknown" sentinel.
        return None if raw == 0xFFFF else (raw * 100) / 1000

    vendor_id = get_two_bytes_from_buffer(result.payload, 4)
    product_id = get_two_bytes_from_buffer(result.payload, 6)
    design_cap = get_two_bytes_from_buffer(result.payload, 8)
    last_full_cap = get_two_bytes_from_buffer(result.payload, 10)
    battery_type = result.payload[12]

    d = dict()
    d["Vendor ID"] = ("0x%04X" % (vendor_id))
    d["Product ID"] = ("0x%04X" % (product_id))
    d["Design Cap (WH)"] = to_wh(design_cap)
    d["Last Full Cap (WH)"] = to_wh(last_full_cap)
    d["Invalid"] = battery_type & 0x01
    d["Reserved"] = (battery_type & 0xFE) >> 1
    return Result(Result.NO_ERROR, d)

#////////////////////////////////////////////////////




#////////////////////////////////////////////////////
#pdRequestBatteryStatus(15) - Battery Status Info PD Message filters
#////////////////////////////////////////////////////
def filter_tx_battery_status(packet, sop):
    """Return True if *packet* is our outgoing Get_Battery_Status request on *sop*.

    A match requires at least a 2-byte header, the given SOP, direction 1
    (TX), and an extended message of type 4
    (kPD_MESSAGE_EXTENDED_TYPE_GET_BATTERY_STATUS).
    """
    if len(packet.payload) < 2:
        return False
    if packet.sop != sop:
        return False

    hdr = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))
    # Direction 1 = TX; extended type 4 = GET_BATTERY_STATUS.
    return bool(packet.direction == 1 and hdr.messageType == 4 and hdr.extended)


def filter_rx_battery_status(packet, sop):
    """Return True if *packet* is a received Battery_Status data message on *sop*.

    Battery_Status is a non-extended data message. The payload must be
    exactly 6 bytes (2-byte header plus one 4-byte data object), with
    direction 2 (RX), message type 5 (kPD_MESSAGE_DATA_TYPE_BATTERY_STATUS),
    and exactly one data object.
    """
    if len(packet.payload) != 6 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))
    return bool(
        packet.direction == 2                   # 2 = RX
        and header.messageType == 5             # BATTERY_STATUS (data message)
        and header.extended == False            # noqa: E712 - equality semantics kept on purpose
        and header.numOfDataObjs == 1
    )


def decode_pdRequestBatteryStatus(result):
    """Decode a Battery_Status data message into a dict.

    The payload must be exactly 6 bytes. The 4-byte data object at offset 2
    is parsed with ``rule_to_battery_status_data_object``.

    "State of Charge (WH)" is ``battery_pc`` scaled by 1/10 (0.1 WH units),
    or None when ``battery_pc`` is 0xFFFF. "Charging Status" is the raw enum
    value:
        0 = kPDBatteryChargeCharging
        1 = kPDBatteryChargeDischarging
        2 = kPDBatteryChargeIdle
        3 = kPDBatteryChargeReserved

    Args:
        result: packet-like object exposing ``payload``.

    Returns:
        ``Result(Result.NO_ERROR, dict)`` for a 6-byte payload, otherwise
        ``Result(Result.SIZE_ERROR, {})``.
    """
    if len(result.payload) != 6:
        return Result(Result.SIZE_ERROR, dict())

    raw_bsdo = get_four_bytes_from_buffer(result.payload, 2)
    bsdo = rule_to_battery_status_data_object(raw_bsdo)

    if bsdo.battery_pc == 0xFFFF:
        present_capacity = None
    else:
        present_capacity = (bsdo.battery_pc * 100) / 1000
    reserved_low = bsdo.reserved
    info = bsdo.battery_info_bit_fields

    fields = {
        "State of Charge (WH)": present_capacity,
        "Reserved 1": reserved_low,
        "Invalid Battery Reference": info.invalid_battery_ref,
        "Battery Present": info.battery_is_present,
        "Charging Status": info.battery_charging_status,
        "Reserved 2": info.reserved,
    }
    return Result(Result.NO_ERROR, fields)
#////////////////////////////////////////////////////




#////////////////////////////////////////////////////
#pdRequestManufacturerInfoSop(16) - Manufacturer Info PD Message filters
#////////////////////////////////////////////////////
def filter_tx_manufacturer_info(packet, sop):
    """Return True if *packet* is our outgoing Get_Manufacturer_Info request on *sop*.

    A match requires at least a 2-byte header, the given SOP, direction 1
    (TX), and an extended message of type 6
    (kPD_MESSAGE_EXTENDED_TYPE_GET_MANUFACTURER_INFO).
    """
    if not (len(packet.payload) >= 2 and packet.sop == sop):
        return False

    msg_header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))
    if packet.direction == 1 and msg_header.messageType == 6 and msg_header.extended:
        return True  # 1 = TX, 6 = GET_MANUFACTURER_INFO
    return False


def filter_rx_manufacturer_info(packet, sop):
    """Return True if *packet* is a received Manufacturer_Info reply on *sop*.

    A match requires at least a 2-byte header, the given SOP, direction 2
    (RX), and an extended message of type 7
    (kPD_MESSAGE_EXTENDED_TYPE_MANUFACTURER_INFO).
    """
    if len(packet.payload) < 2 or packet.sop != sop:
        return False

    header = PDMessageHeader(get_two_bytes_from_buffer(packet.payload, 0))
    received = packet.direction == 2               # 2 = RX
    is_mfg_info = received and header.messageType == 7
    return bool(is_mfg_info and header.extended)


def decode_pdRequestManufacturerInfoSop(result):
    """Decode a Manufacturer_Info extended message into a dict.

    Byte offsets read from ``result.payload``:
        0-1  message header (not decoded here)
        2-3  extended message header (not decoded here)
        4-5  Vendor ID
        6-7  Product ID
        8..  manufacturer string, one character per byte

    The USB PD manufacturer string may be null-terminated, and extended
    message payloads may carry zero padding. The string is therefore cut at
    the first NUL byte, so terminators and padding do not appear in
    "Mfg String".

    Args:
        result: packet-like object exposing ``payload``.

    Returns:
        ``Result(Result.NO_ERROR, dict)`` on success, or
        ``Result(Result.SIZE_ERROR, {})`` if the payload is shorter than the
        8 bytes needed for the header, extended header, VID and PID.
    """
    if len(result.payload) < 8:
        return Result(Result.SIZE_ERROR, dict())

    vendor_id = get_two_bytes_from_buffer(result.payload, 4)
    product_id = get_two_bytes_from_buffer(result.payload, 6)

    string_bytes = list(result.payload[8:])
    if 0 in string_bytes:
        # Drop the terminator and anything after it (e.g. alignment padding).
        string_bytes = string_bytes[:string_bytes.index(0)]
    mfg_string = ''.join(chr(x) for x in string_bytes)

    d = dict()
    d["Vendor ID"] = ("0x%04X" % (vendor_id))
    d["Product ID"] = ("0x%04X" % (product_id))
    d["Mfg String"] = mfg_string
    return Result(Result.NO_ERROR, d)
#////////////////////////////////////////////////////








#////////////////////////////////////////////////////
#Note: Filtering is minimal here because DFU commands travel
#over SOP'_Debug and SOP''_Debug, which carry little other traffic.
#////////////////////////////////////////////////////
def filter_tx_apple_dfu_vdm(packet, sop):
    """Accept any packet we transmitted on *sop*.

    Only the SOP type and the direction (1 = TX) are checked; the message
    contents are not inspected.

    Returns:
        bool: True when ``packet.sop == sop`` and ``packet.direction == 1``.
    """
    if packet.sop != sop:
        return False
    return bool(packet.direction == 1)


def filter_rx_apple_dfu_vdm(packet, sop):
    """Accept any packet received on *sop*.

    Only the SOP type and the direction (2 = RX) are checked; the message
    contents are not inspected.

    Returns:
        bool: True when ``packet.sop == sop`` and ``packet.direction == 2``.
    """
    on_requested_sop = packet.sop == sop
    return bool(on_requested_sop and packet.direction == 2)
#////////////////////////////////////////////////////
