# ###############################################################################
#
# Copyright (C) 2024 Arm Limited (or its affiliates). All rights reserved.
#
# Zephyr RTOS OS Awareness Extension for ARM Development Studio
#
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from osapi import *
from globs import *

# Object format codes.
# Each code selects one conversion branch in formatExpr( ); a code with no
# branch there yields the caller's default value instead.
FORMAT_NONE            = 0   # expression object returned unchanged
FORMAT_ADDRESS         = 1   # expr.readAsAddress( )
FORMAT_ADDRESS_LIN     = 2   # linear address of expr.readAsAddress( )
FORMAT_ADDRESS_STR     = 3   # expr.readAsAddress( ).toString( )
FORMAT_HEX             = 4   # hex( ) string of expr.readAsNumber( )
FORMAT_NUMBER          = 5   # expr.readAsNumber( )
FORMAT_NUMBER_STR      = 6   # str( ) of expr.readAsNumber( )
FORMAT_STRING          = 7   # expr read as a null-terminated string
FORMAT_STRING_PTR      = 8   # string at the char* address held by expr (default when NULL)
FORMAT_ARRAY           = 9   # array elements of debugSession.evaluateExpression( expr )
FORMAT_LOCATION        = 10  # hex( ) of expr's linear location address
FORMAT_THREAD_LIST     = 11  # NOTE(review): no branch in formatExpr( ) handles this code
FORMAT_THREAD_STATE    = 12  # state name via getThreadStateName( )
FORMAT_YES_NO          = 13  # "No" when the value is zero, otherwise "Yes"
FORMAT_PRIORITY        = 14  # priority with a "(Cooperative)"/"(Preemptive)" suffix

# Thread state names keyed by single state bit.
# NOTE(review): the bit values appear to mirror Zephyr's _THREAD_* thread_state
# flags; confirm against the targeted Zephyr version. This table is not used by
# getThreadStateName( ) in this file, which tests THREAD_STATE_* constants.
THREAD_STATE_NAMES = \
{
    0x00 : "READY",
    0x01 : "DUMMY",
    0x02 : "PENDING",
    0x04 : "PRESTART",
    0x08 : "DEAD",
    0x10 : "SUSPENDED",
    0x20 : "ABORTING",
    0x80 : "QUEUED"
}

# FPU status text per register-map name, looked up by getFPUStatusText( ).
# NOTE(review): getRegMapName( ) also returns REG_MAP_V7MBASIC when an FPU is
# present but CPACR grants no access, which this table reports as "Not Present".
REGMAPNAMES = \
{
    REG_MAP_V8MBASIC  : "Disabled",
    REG_MAP_V8MVFP    : "Enabled",
    REG_MAP_V7MBASIC  : "Not Present",
    REG_MAP_V7MVFP    : "Enabled",
    REG_MAP_V7ABASIC  : "Disabled",
    REG_MAP_V7AVFP    : "Enabled"
}

# Convert a debugger expression into a display value
def formatExpr( expr, exprType, debugSession, defVal="N/A" ):
    """Render *expr* according to the FORMAT_* code *exprType*.

    expr         -- expression object; for FORMAT_ARRAY it is passed straight
                    to debugSession.evaluateExpression( ) (presumably text)
    exprType     -- one of the FORMAT_* codes defined in this module
    debugSession -- session used by FORMAT_STRING_PTR and FORMAT_ARRAY
    defVal       -- returned for unhandled codes and for a NULL string pointer
    """
    if exprType == FORMAT_NONE:
        return expr

    # Address renderings
    if exprType == FORMAT_ADDRESS:
        return expr.readAsAddress( )
    if exprType == FORMAT_ADDRESS_LIN:
        return expr.readAsAddress( ).getLinearAddress( )
    if exprType == FORMAT_ADDRESS_STR:
        return expr.readAsAddress( ).toString( )

    # Numeric renderings
    if exprType == FORMAT_HEX:
        return hex( expr.readAsNumber( ) )
    if exprType == FORMAT_NUMBER:
        return expr.readAsNumber( )
    if exprType == FORMAT_NUMBER_STR:
        return str( expr.readAsNumber( ) )

    if exprType == FORMAT_STRING:
        return expr.readAsNullTerminatedString( )

    if exprType == FORMAT_STRING_PTR:
        charPtr = expr.readAsNumber( )
        if charPtr == 0:
            # NULL pointer: nothing to read
            return defVal
        castExpr = "(char*)" + str( charPtr )
        return debugSession.evaluateExpression( castExpr ).readAsNullTerminatedString( )

    if exprType == FORMAT_ARRAY:
        return debugSession.evaluateExpression( expr ).getArrayElements( )

    if exprType == FORMAT_LOCATION:
        return hex( expr.getLocationAddress( ).getLinearAddress( ) )

    if exprType == FORMAT_THREAD_STATE:
        return getThreadStateName( expr.readAsNumber( ) )

    if exprType == FORMAT_YES_NO:
        return "No" if expr.readAsNumber( ) == 0 else "Yes"

    if exprType == FORMAT_PRIORITY:
        priority = expr.readAsNumber( )
        # Zephyr uses signed priority: negative = cooperative, otherwise preemptive
        suffix = " (Cooperative)" if priority < 0 else " (Preemptive)"
        return str( priority ) + suffix

    # Unhandled format code
    return defVal

# Get the formatted value of a structure member
def getMemberValue( members, member, exprType, debugSession, defVal="N/A" ):
    """Return member *member* of *members*, formatted via formatExpr( ).

    members      -- mapping of member name to expression (needs 'in' and [])
    member       -- member name to look up
    exprType     -- FORMAT_* code forwarded to formatExpr( )
    debugSession -- session forwarded to formatExpr( )
    defVal       -- returned when the member is absent; also forwarded to
                    formatExpr( ) so it applies when the member exists but
                    cannot be formatted (unhandled code, NULL string pointer)
    """
    if member not in members:
        return defVal

    # Forward defVal: previously it was dropped here, so formatExpr( ) fell
    # back to its own "N/A" instead of the caller's default.
    return formatExpr( members[ member ], exprType, debugSession, defVal )

# Return the member name when it is present, else an empty string
def getMemberName( member, members ):
    """Return *member* if it is contained in *members*, otherwise ""."""
    return member if member in members else ""

# Decode a thread state bitmask into a display name
def getThreadStateName( stateBitmask ):
    """Return the display name for *stateBitmask*.

    A zero mask is "READY". Otherwise the state bits are tested in a fixed
    precedence order (DEAD, SUSPENDED, ABORTING, PENDING, PRESTART, QUEUED,
    DUMMY) and the first bit that is set wins. A mask with none of these bits
    set is "UNKNOWN".
    """
    if stateBitmask == 0:
        return "READY"

    precedence = (
        ( THREAD_STATE_DEAD,      "DEAD" ),
        ( THREAD_STATE_SUSPENDED, "SUSPENDED" ),
        ( THREAD_STATE_ABORTING,  "ABORTING" ),
        ( THREAD_STATE_PENDING,   "PENDING" ),
        ( THREAD_STATE_PRESTART,  "PRESTART" ),
        ( THREAD_STATE_QUEUED,    "QUEUED" ),
        ( THREAD_STATE_DUMMY,     "DUMMY" ),
    )
    for bit, label in precedence:
        if stateBitmask & bit:
            return label
    return "UNKNOWN"

# Return the linear address at which an expression is located
def addressExprsToLong( expr ):
    """Return the linear (numeric) location address of *expr*."""
    return expr.getLocationAddress( ).getLinearAddress( )

# Collect threads by following THREAD_NEXT_THREAD links from a head pointer
def readThreadItems( debugSession, headPtr ):
    """Return the thread expressions reachable from *headPtr*, in link order.

    headPtr      -- pointer expression to the first thread (NULL -> empty list)
    debugSession -- not used here; kept for interface compatibility

    The walk stops at a NULL next pointer, at a thread that lacks the
    THREAD_NEXT_THREAD member, or at a thread whose location address is 0 or
    has already been seen (guards against corrupt or circular lists).
    """
    collected = [ ]
    if headPtr.readAsNumber( ) == 0:
        return collected

    thread = headPtr.dereferencePointer( )
    seen = set( )

    while True:
        location = thread.getLocationAddress( ).getLinearAddress( )
        if location == 0 or location in seen:
            break
        seen.add( location )
        collected.append( thread )

        fields = thread.getStructureMembers( )
        if THREAD_NEXT_THREAD not in fields:
            break
        link = fields[ THREAD_NEXT_THREAD ]
        if link.readAsNumber( ) == 0:
            break
        thread = link.dereferencePointer( K_THREAD + "*" )

    return collected

# Read the items of a Zephyr sys_dlist_t (circular doubly-linked list)
def readDListItems( debugSession, dlistExpr, itemType, nodeOffset=0 ):
    """Return the items linked on the dlist *dlistExpr*, in list order.

    debugSession -- session used to build item expressions
    dlistExpr    -- expression for the list object in target memory (its
                    location address is read, so it must be an lvalue)
    itemType     -- C type name of the containing item (e.g. K_THREAD)
    nodeOffset   -- byte offset of the list node within the item

    The walk starts at the DLIST_HEAD pointer and follows each node's "next"
    pointer. It stops at NULL, at a node already seen, or when it returns to
    the list object itself. A Zephyr sys_dlist_t is circular with the list as
    sentinel: an empty list's head points back at the list, and so does the
    last node's next pointer.
    """
    items = [ ]

    dlistMembers = dlistExpr.getStructureMembers( )
    if DLIST_HEAD not in dlistMembers:
        return items

    currentNode = dlistMembers[ DLIST_HEAD ]
    if currentNode.readAsNumber( ) == 0:
        return items

    # Seed the visited set with the list's own address so the sentinel is
    # never reported as an item (and an empty list yields no items).
    listAddr = dlistExpr.getLocationAddress( ).getLinearAddress( )
    visitedAddresses = set( [ listAddr ] )

    while True:
        nodeAddr = currentNode.readAsNumber( )
        if nodeAddr == 0 or nodeAddr in visitedAddresses:
            break
        visitedAddresses.add( nodeAddr )

        # The node is embedded in the item: step back to the item's start
        itemAddr = nodeAddr - nodeOffset
        itemPtr = debugSession.evaluateExpression( "(" + itemType + "*)" + str( itemAddr ) )
        items.append( itemPtr.dereferencePointer( ) )

        # Advance to the next node
        nodeMembers = currentNode.dereferencePointer( ).getStructureMembers( )
        if "next" not in nodeMembers:
            break
        currentNode = nodeMembers[ "next" ]

    return items

# Get a comma-separated list of thread names
def getThreadNamesFromList( threadList, debugSession ):
    """Return the THREAD_NAME of each thread in *threadList*, joined by ", ".

    Threads without a THREAD_NAME member are skipped. "N/A" is returned when
    no names were collected. *debugSession* is not used.
    """
    names = [ ]
    for thread in threadList:
        members = thread.getStructureMembers( )
        if THREAD_NAME in members:
            names.append( members[ THREAD_NAME ].readAsNullTerminatedString( ) )

    # Collect into a list instead of comparing against the "N/A" placeholder,
    # so a thread literally named "N/A" is not dropped. Joining once also
    # avoids re-concatenating the string on every iteration.
    if not names:
        return "N/A"
    return ", ".join( names )

# Select the register map matching the target architecture and FPU state
def getRegMapName( stackPointer, debugSession ):
    """Return the REG_MAP_* name describing the current target.

    stackPointer -- not used here; kept for interface compatibility
    debugSession -- session queried for the architecture, $FPSCR and $CPACR

    M-profile targets report an FP map only when $FPSCR exists and any of
    CPACR bits [23:20] is set. v7-A/R targets report an FP map whenever
    $FPSCR exists. Unrecognised architectures fall back to REG_MAP_V7MBASIC.
    """

    def fpscrPresent( ):
        return debugSession.symbolExists( "$FPSCR" ) == 1

    def cpacrGrantsFPAccess( ):
        # CPACR[23:20] hold the coprocessor access fields used by the FPU
        cpacr = debugSession.evaluateExpression( "$CPACR" ).readAsNumber( )
        return ( cpacr & ( 0xF << 20 ) ) != 0

    arch = debugSession.getTargetInformation( ).getArchitecture( ).getName( )

    # Cortex-M (v6-M / v7-M)
    if arch in ( "ARMv7M", "ARMv6M" ):
        if fpscrPresent( ) and cpacrGrantsFPAccess( ):
            return REG_MAP_V7MVFP
        return REG_MAP_V7MBASIC

    # Cortex-M (v8-M)
    if arch == "ARMv8M":
        if fpscrPresent( ) and cpacrGrantsFPAccess( ):
            return REG_MAP_V8MVFP
        return REG_MAP_V8MBASIC

    # Cortex-A/R (v7)
    if arch in ( "ARMv7A", "ARMv7R" ):
        if fpscrPresent( ):
            return REG_MAP_V7AVFP
        return REG_MAP_V7ABASIC

    return REG_MAP_V7MBASIC

def make_reg_list(offset, size, *reg_names):
    """Pair each register name with its offset.

    The i-th entry of reg_names is paired with offset + i*size (as a long),
    so registers are laid out consecutively, size apart.
    """
    pairs = []
    for index, name in enumerate(reg_names):
        pairs.append(("%s" % name, long(offset + index * size)))
    return pairs

def make_reg_range(offset, size, prefix, start, count):
    """Build count (name, offset) pairs for a numbered register bank.

    Entry i is named prefix followed by the number start+i and sits at
    offset + i*size (as a long).
    """
    pairs = []
    for index in xrange(count):
        regName = "%s%d" % (prefix, start + index)
        pairs.append((regName, long(offset + index * size)))
    return pairs

# Translate a register-map name into FPU status text
def getFPUStatusText( regMapName ):
    """Return the REGMAPNAMES text for *regMapName*, or "Unknown"."""
    return REGMAPNAMES.get( regMapName, "Unknown" )

# Append a table cell built from a structure member
def createCell( cells, members, itemName, itemFormat, cellType, debugSession ):
    """Append one cell for member *itemName* to *cells*.

    cellType selects the cell factory: ADDRESS -> createAddressCell, DECIMAL
    -> createNumberCell, anything else -> createTextCell. The value comes from
    getMemberValue( ) with the given format code.
    """
    if cellType == ADDRESS:
        makeCell = createAddressCell
    elif cellType == DECIMAL:
        makeCell = createNumberCell
    else:
        makeCell = createTextCell

    value = getMemberValue( members, itemName, itemFormat, debugSession )
    cells.append( makeCell( value ) )

# Check for stack overflow using Zephyr's stack sentinel
def checkStackOverflow( debugSession, threadMembers ):
    """Return True when the sentinel word at the start of the thread's stack
    no longer equals STACK_SENTINEL_MAGIC.

    Reads the unsigned int at THREAD_STACK_INFO.STACK_INFO_START. Returns
    False when either member is missing or when anything goes wrong while
    reading target memory.
    NOTE(review): if the target build never writes the sentinel, presumably
    every thread would be flagged; confirm the kernel config.
    """
    # The bare except is intentional (kept from the original). Under Jython,
    # debugger API failures may surface as Java exceptions that are not
    # guaranteed to be caught by "except Exception"; confirm before narrowing.
    try:
        if THREAD_STACK_INFO not in threadMembers:
            return False
        stackInfo = threadMembers[ THREAD_STACK_INFO ].getStructureMembers( )
        if STACK_INFO_START not in stackInfo:
            return False

        lowestAddr = stackInfo[ STACK_INFO_START ].readAsNumber( )
        sentinelExpr = "*(unsigned int*)" + str( lowestAddr )
        sentinel = debugSession.evaluateExpression( sentinelExpr ).readAsNumber( )
        return sentinel != STACK_SENTINEL_MAGIC
    except:
        return False

# Read a wait queue and return the threads waiting on it
def readWaitQueue( debugSession, waitQExpr ):
    """Return the K_THREAD items on the WAITQ_WAITQ dlist of *waitQExpr*.

    Returns an empty list when the member is missing or when any error
    occurs while reading the queue.
    """
    # The bare except is intentional (kept from the original): debugger API
    # failures may surface as Java exceptions under Jython; confirm before
    # narrowing.
    try:
        queueMembers = waitQExpr.getStructureMembers( )
        if WAITQ_WAITQ not in queueMembers:
            return [ ]
        return readDListItems( debugSession, queueMembers[ WAITQ_WAITQ ], K_THREAD )
    except:
        return [ ]
