# ###############################################################################
#
# Copyright (C) 2024 Arm Limited (or its affiliates). All rights reserved.
#
# Zephyr RTOS OS Awareness Extension for ARM Development Studio
#
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from osapi import *
from utils import *
from globs import *

# Threads class
class Threads( Table ):

    # Class ID (also used as the table ID passed to the field/record helpers)
    ID = "threads"

    # Column definitions - thread name, read through getMemberValue( ) using
    # the format in column 1; column 2 is the field's display type
    cols = \
    [
        [ THREAD_NAME,               FORMAT_STRING,      TEXT    ]
    ]

    # Column definitions 2 - thread details from base structure
    # NOTE(review): not referenced by this class; kept as-is in case other
    # code uses it - verify before removing
    cols2 = \
    [
        [ THREAD_BASE_PRIO,          FORMAT_NUMBER,      DECIMAL ],
        [ THREAD_BASE_THREAD_STATE,  FORMAT_THREAD_STATE, TEXT   ],
    ]

    # Column definitions 3 - stack information
    # NOTE(review): also not referenced by this class
    cols3 = \
    [
        [ STACK_INFO_START,          FORMAT_ADDRESS,     ADDRESS ],
        [ STACK_INFO_SIZE,           FORMAT_NUMBER,      DECIMAL ],
    ]

    # Class constructor: declares the table columns. The order here must match
    # the order in which readThread( ) appends cells.
    def __init__( self ):

        cid = self.ID

        # Primary key column: the thread ID string built in addThread( )
        fields = [ createPrimaryField( cid, K_THREAD, TEXT ) ]

        # Thread name
        for col in self.cols:
            fields.append( createField( cid, col[ 0 ], col[ 2 ] ) )

        # State, priority, FPU usage and stack overflow status
        fields.append( createField( cid, "state", TEXT ) )
        fields.append( createField( cid, "priority", DECIMAL ) )
        fields.append( createField( cid, "usingFPU", TEXT ) )
        fields.append( createField( cid, "stackOverflow", TEXT ) )

        # Stack start and size
        fields.append( createField( cid, "stackStart", ADDRESS ) )
        fields.append( createField( cid, "stackSize", DECIMAL ) )

        # Saved process stack pointer (for debugging)
        fields.append( createField( cid, "psp", ADDRESS ) )

        Table.__init__( self, cid, fields )

    # Format an integer as "0x..." exactly like hex( ), but without the
    # trailing "L" that Python 2 / Jython's hex( ) appends to long values
    # (values read through the debugger API may be longs there).
    @staticmethod
    def _toHex( value ):
        if value < 0:
            return "-0x%x" % -value
        return "0x%x" % value

    # Return the members of the thread's THREAD_BASE structure, or None when
    # the thread structure has no such member.
    @staticmethod
    def _getBaseMembers( members ):
        if THREAD_BASE in members:
            return members[ THREAD_BASE ].getStructureMembers( )
        return None

    # Return the saved PSP read from the thread's callee-saved register block
    # (as an address object), or None when that information is not available.
    @staticmethod
    def _readStackPointer( members ):
        if THREAD_CALLEE_SAVED in members:
            calleeSaved = members[ THREAD_CALLEE_SAVED ].getStructureMembers( )
            if CALLEE_SAVED_PSP in calleeSaved:
                return calleeSaved[ CALLEE_SAVED_PSP ].readAsAddress( )
        return None

    # Decode a thread's state name from the thread-state bitmask in its base
    # structure, or "UNKNOWN" when that member is not available.
    def _readThreadState( self, thread ):
        baseMembers = self._getBaseMembers( thread.getStructureMembers( ) )
        if baseMembers is not None and THREAD_BASE_THREAD_STATE in baseMembers:
            stateBitmask = baseMembers[ THREAD_BASE_THREAD_STATE ].readAsNumber( )
            return getThreadStateName( stateBitmask )
        return "UNKNOWN"

    # Build one table record for a thread.
    #   cid          - thread ID string used as the primary key cell
    #   members      - structure members of the thread structure
    #   state        - state text to display
    #   debugSession - debugger session used for memory/register reads
    # Returns the record created by self.createRecord( ).
    def readThread( self, cid, members, state, debugSession ):

        # Primary field
        cells = [ createTextCell( cid ) ]

        # Thread name
        for col in self.cols:
            cells.append( createTextCell( getMemberValue( members, col[ 0 ], col[ 1 ], debugSession ) ) )

        # Thread state
        cells.append( createTextCell( state ) )

        # Priority from the base structure
        priority = "N/A"
        baseMembers = self._getBaseMembers( members )
        if baseMembers is not None and THREAD_BASE_PRIO in baseMembers:
            prio = baseMembers[ THREAD_BASE_PRIO ].readAsNumber( )
            # Cooperative threads have negative priority: if the 8-bit value
            # was read back as unsigned, sign-extend it
            if prio > 127:
                prio = prio - 256
            priority = str( prio )
        cells.append( createNumberCell( priority ) )

        # FPU usage, determined from the register map chosen for the saved PSP
        stackPointer = self._readStackPointer( members )
        if stackPointer:
            regMapName = getRegMapName( stackPointer, debugSession )
            cells.append( createTextCell( getFPUStatusText( regMapName ) ) )
        else:
            cells.append( createTextCell( "N/A" ) )

        # Stack overflow status
        if checkStackOverflow( debugSession, members ):
            cells.append( createTextCell( "OVERFLOW!" ) )
        else:
            cells.append( createTextCell( "OK" ) )

        # Stack start and size
        stackStart = "N/A"
        stackSize = "N/A"
        if THREAD_STACK_INFO in members:
            stackInfo = members[ THREAD_STACK_INFO ].getStructureMembers( )
            if STACK_INFO_START in stackInfo:
                stackStart = self._toHex( stackInfo[ STACK_INFO_START ].readAsNumber( ) )
            if STACK_INFO_SIZE in stackInfo:
                stackSize = str( stackInfo[ STACK_INFO_SIZE ].readAsNumber( ) )

        # NOTE(review): the "stackStart" column is declared ADDRESS in
        # __init__ but is filled with a text cell here - confirm intended
        cells.append( createTextCell( stackStart ) )
        cells.append( createNumberCell( stackSize ) )

        # Saved PSP value
        pspValue = "N/A"
        if stackPointer:
            pspValue = self._toHex( stackPointer.getLinearAddress( ) )
        cells.append( createAddressCell( pspValue ) )

        return self.createRecord( cells )

    # Append a record for one thread to records. The thread ID is the hex
    # string of the thread structure's address.
    def addThread( self, thread, threadState, records, debugSession ):

        pThreadId = self._toHex( thread.getLocationAddress( ).getLinearAddress( ) )
        pThreadMembers = thread.getStructureMembers( )
        records.append( self.readThread( pThreadId, pThreadMembers, threadState, debugSession ) )

    # Return the records for all threads. The running thread comes first,
    # followed by the other threads from the kernel's thread list. Returns an
    # empty list when the kernel symbol does not exist.
    def getRecords( self, debugSession ):

        records = [ ]

        # Make sure OS is initialized
        if not debugSession.symbolExists( Z_KERNEL ):
            return records

        kernelExpr = debugSession.evaluateExpression( Z_KERNEL )
        kernelMembers = kernelExpr.getStructureMembers( )

        # Current (running) thread; 0 means "none", so no list entry is skipped
        currentThreadAddr = 0
        if Z_KERNEL_CURRENT in kernelMembers:
            currentThreadPtr = kernelMembers[ Z_KERNEL_CURRENT ]
            currentThreadAddr = currentThreadPtr.readAsNumber( )
            if currentThreadAddr:
                currentThread = currentThreadPtr.dereferencePointer( )
                self.addThread( currentThread, "RUNNING", records, debugSession )

        # Remaining threads from the kernel's linked list
        if Z_KERNEL_THREADS in kernelMembers:
            threadsPtr = kernelMembers[ Z_KERNEL_THREADS ]
            if threadsPtr.readAsNumber( ):
                for thread in readThreadItems( debugSession, threadsPtr ):
                    # Skip the current thread (already added above)
                    if thread.getLocationAddress( ).getLinearAddress( ) == currentThreadAddr:
                        continue
                    self.addThread( thread, self._readThreadState( thread ), records, debugSession )

        return records
