# ###############################################################################
#
# Copyright (C) 2024 Arm Limited (or its affiliates). All rights reserved.
#
# Zephyr RTOS OS Awareness Extension for ARM Development Studio
#
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from osapi import *
from globs import *
from utils import *
from itertools import *


# ARM Cortex-M Basic Registers Map (no FPU)
# Zephyr stores callee-saved registers on context switch
M_CLASS_BASIC_REGISTERS_MAP = \
{
       "R4":   0L,  # v1
       "R5":   4L,  # v2
       "R6":   8L,  # v3
       "R7":  12L,  # v4
       "R8":  16L,  # v5
       "R9":  20L,  # v6
      "R10":  24L,  # v7
      "R11":  28L,  # v8
       "SP":  32L,  # PSP (Process Stack Pointer)
       "LR":  36L,  # Link Register (EXC_RETURN)
# Exception frame (pushed by hardware)
       "R0":  40L,
       "R1":  44L,
       "R2":  48L,
       "R3":  52L,
      "R12":  56L,
       "PC":  64L,
     "xPSR":  68L,
     "CPSR":  68L,
}

# ARM Cortex-M extended register map (FPU present, but the thread has not used
# it). The saved layout is currently the same as the basic (no-FPU) map, so the
# table is built as an independent copy of it rather than repeated by hand.
# If the two layouts ever need to differ, replace this with an explicit table.
M_CLASS_EXTENDED_REGISTERS_MAP = dict( M_CLASS_BASIC_REGISTERS_MAP )

# ARM Cortex-M FPU Registers Map (with FPU, task using FPU)
M_CLASS_FPU_REGISTERS_MAP = \
{
       "R4":   0L,
       "R5":   4L,
       "R6":   8L,
       "R7":  12L,
       "R8":  16L,
       "R9":  20L,
      "R10":  24L,
      "R11":  28L,
       "SP":  32L,
       "LR":  36L,
      "S16":  40L,
      "S17":  44L,
      "S18":  48L,
      "S19":  52L,
      "S20":  56L,
      "S21":  60L,
      "S22":  64L,
      "S23":  68L,
      "S24":  72L,
      "S25":  76L,
      "S26":  80L,
      "S27":  84L,
      "S28":  88L,
      "S29":  92L,
      "S30":  96L,
      "S31": 100L,
       "D8":  40L,
       "D9":  48L,
      "D10":  56L,
      "D11":  64L,
      "D12":  72L,
      "D13":  80L,
      "D14":  88L,
      "D15":  96L,
       "R0": 104L,
       "R1": 108L,
       "R2": 112L,
       "R3": 116L,
      "R12": 120L,
       "PC": 128L,
     "xPSR": 132L,
     "CPSR": 132L,
       "S0": 136L,
       "S1": 140L,
       "S2": 144L,
       "S3": 148L,
       "S4": 152L,
       "S5": 156L,
       "S6": 160L,
       "S7": 164L,
       "S8": 168L,
       "S9": 172L,
      "S10": 176L,
      "S11": 180L,
      "S12": 184L,
      "S13": 188L,
      "S14": 192L,
      "S15": 196L,
       "D0": 136L,
       "D1": 144L,
       "D2": 152L,
       "D3": 160L,
       "D4": 168L,
       "D5": 176L,
       "D6": 184L,
       "D7": 192L,
    "FPSCR": 200L,
}

# ARM Cortex-A/R Basic Registers Map
A_CLASS_BASIC_REGISTERS_MAP = \
{
       "R0":   0L,
       "R1":   4L,
       "R2":   8L,
       "R3":  12L,
       "R4":  16L,
       "R5":  20L,
       "R6":  24L,
       "R7":  28L,
       "R8":  32L,
       "R9":  36L,
      "R10":  40L,
      "R11":  44L,
      "R12":  48L,
       "LR":  52L,
       "PC":  56L,
     "xPSR":  60L,
     "CPSR":  60L,
       "SP":  64L
}

# ARM Cortex-A/R FPU register map.
# Built by chaining (name, offset) sequences into one dict; offsets are bytes
# from the thread's saved "stack_ptr". "SP" is not read from memory:
# getOSContextSavedRegister reports it as stack_ptr + its offset.
# NOTE(review): assumes make_reg_range( base, stride, prefix, first, count )
# yields ( prefix + str(n), base + stride * (n - first) ) pairs and that
# make_reg_list( base, stride, *names ) assigns consecutive slots -- confirm in
# utils. Nothing here is mapped at offset 0 (FPSCR is at 4); confirm what that
# word holds.
A_CLASS_FPU_REGISTERS_MAP = dict(chain(
    [("FPSCR",   4L)],
    # Thirty-two 64-bit double-word registers, D0-D31 (FPU register bank).
    make_reg_range(8L, 8L, "D", 0, 32),
    # Thirty-two 32-bit single-word registers, S0-S31. They share the same base
    # offset as D0 because S<2n>/S<2n+1> overlay D<n> in the FPU register bank.
    make_reg_range(8L, 4L, "S", 0, 32),
    # Basic registers R0-R12
    make_reg_range(268L, 4L, "R", 0, 13),
    # LR, PC, CPSR and SP, presumably one word apart starting at offset 320
    make_reg_list(320, 4, "LR", "PC", "CPSR", "SP")
    ))

class ContextsProvider( ExecutionContextsProvider ):
    """Execution-contexts provider exposing Zephyr threads to the debugger.

    Threads are discovered through the kernel structure named by Z_KERNEL.
    Each thread becomes an OSContext whose additional data holds:
      "stack_ptr"    -- address of the thread's saved register frame
      "register_map" -- register name -> byte offset from "stack_ptr"
    """

    # Get context of current executing thread
    def getCurrentOSContext( self, debugger ) :
        """Return the OSContext of the running thread.

        Returns None when the kernel symbol does not exist, it has no
        current-thread member, or the current-thread pointer is NULL.
        """
        if not debugger.symbolExists( Z_KERNEL ) :
            return None

        kernelMembers = debugger.evaluateExpression( Z_KERNEL ).getStructureMembers( )
        if Z_KERNEL_CURRENT not in kernelMembers:
            return None

        currentThreadPtr = kernelMembers[ Z_KERNEL_CURRENT ]
        if not currentThreadPtr.readAsNumber( ) :
            return None

        # A listName of None marks the context as RUNNING
        return self.createContextFromThread( debugger, currentThreadPtr.dereferencePointer( ), None )

    # Get context of all created threads
    def getAllOSContexts( self, debugger ):
        """Return a list of OSContexts for every known thread.

        The running thread (if any) comes first. The remaining threads are
        read from the kernel's thread list via readThreadItems, skipping the
        running thread so it is not reported twice. Returns an empty list when
        the kernel symbol does not exist.
        """
        contexts = [ ]

        if not debugger.symbolExists( Z_KERNEL ):
            return contexts

        kernelMembers = debugger.evaluateExpression( Z_KERNEL ).getStructureMembers( )

        # Running thread first; remember its address to de-duplicate below
        currentThreadAddr = 0
        if Z_KERNEL_CURRENT in kernelMembers:
            currentThreadPtr = kernelMembers[ Z_KERNEL_CURRENT ]
            currentThreadAddr = currentThreadPtr.readAsNumber( )
            if currentThreadAddr:
                currentThread = currentThreadPtr.dereferencePointer( )
                contexts.append( self.createContextFromThread( debugger, currentThread, None ) )

        # Every other thread from the kernel's thread list
        if Z_KERNEL_THREADS in kernelMembers:
            threadsPtr = kernelMembers[ Z_KERNEL_THREADS ]
            if threadsPtr.readAsNumber( ):
                for thread in readThreadItems( debugger, threadsPtr ):
                    threadAddr = thread.getLocationAddress( ).getLinearAddress( )
                    if threadAddr != currentThreadAddr:
                        contexts.append( self.createContextFromThread( debugger, thread, "THREADS" ) )

        return contexts

    # Get register contents saved on thread stack
    def getOSContextSavedRegister( self, debugger, context, name ):
        """Return an expression for register `name` of a switched-out thread.

        Returns None when the thread's register map has no entry for `name`.
        "SP" is special: its value is the computed address stack_ptr + offset
        itself, not a value read from that address. Every other register is
        returned as a (long*) expression pointing at its saved slot.
        """
        additionalData = context.getAdditionalData( )

        # Registers absent from the map are not saved for this thread
        offset = additionalData[ "register_map" ].get( name )
        if offset is None:
            return None

        # Address of the register's slot: saved stack pointer + map offset
        location = additionalData[ "stack_ptr" ].addOffset( offset )

        if name == "SP":
            return debugger.evaluateExpression( "(long)" + str( location ) )
        return debugger.evaluateExpression( "(long*)" + str( location ) )

    # Create context from thread structure
    def createContextFromThread( self, debugger, thread, listName ):
        """Build an OSContext for the k_thread expression `thread`.

        The thread's own address is used as its id. A falsy `listName` marks
        the thread as RUNNING; otherwise the state is decoded from the base
        structure's state bitmask. The saved stack pointer and matching
        register map are stored in the context's additional data. When no
        stack pointer can be found, address 0 is stored instead.
        """
        members = thread.getStructureMembers( )

        # Use the thread structure's address as a unique id
        threadId = thread.getLocationAddress( ).getLinearAddress( )

        name = "Unknown"
        if THREAD_NAME in members:
            namePtr = members[ THREAD_NAME ]
            if namePtr.readAsNumber( ) != 0:
                name = namePtr.readAsNullTerminatedString( )

        state = "UNKNOWN"
        if THREAD_BASE in members:
            baseMembers = members[ THREAD_BASE ].getStructureMembers( )
            if THREAD_BASE_THREAD_STATE in baseMembers:
                stateBitmask = baseMembers[ THREAD_BASE_THREAD_STATE ].readAsNumber( )
                state = getThreadStateName( stateBitmask )

        # No list name means this is the current (running) thread
        if not listName:
            state = "RUNNING"

        context = OSContext( threadId, name, state )

        stackPointer = self._readSavedStackPointer( members )
        if stackPointer is None:
            # Fallback so "stack_ptr" is always an address object
            stackPointer = debugger.evaluateExpression( "0" ).readAsAddress( )

        additionalData = context.getAdditionalData( )
        additionalData[ "stack_ptr" ] = stackPointer
        additionalData[ "register_map" ] = self.getRegisterMap( getRegMapName( stackPointer, debugger ) )

        return context

    def _readSavedStackPointer( self, members ):
        """Return the thread's saved stack pointer address, or None if unknown.

        callee_saved.psp is preferred, and switch_handle is used as a fallback.
        A zero value in either field is treated as absent, so a thread whose
        PSP is zero still gets its switch_handle.
        """
        if THREAD_CALLEE_SAVED in members:
            calleeSaved = members[ THREAD_CALLEE_SAVED ].getStructureMembers( )
            if CALLEE_SAVED_PSP in calleeSaved:
                psp = calleeSaved[ CALLEE_SAVED_PSP ]
                if psp.readAsNumber( ) != 0:
                    return psp.readAsAddress( )

        if THREAD_SWITCH_HANDLE in members:
            switchHandle = members[ THREAD_SWITCH_HANDLE ]
            if switchHandle.readAsNumber( ) != 0:
                return switchHandle.readAsAddress( )

        return None

    # Get register map
    def getRegisterMap( self, value ):
        """Return the register-offset table for register-map name `value`.

        Unrecognised names fall back to the Cortex-M basic (no FPU) map.
        """
        if value == REG_MAP_V7AVFP:
            return A_CLASS_FPU_REGISTERS_MAP
        if value == REG_MAP_V7ABASIC:
            return A_CLASS_BASIC_REGISTERS_MAP
        if value in ( REG_MAP_V7MVFP, REG_MAP_V8MVFP ):
            return M_CLASS_FPU_REGISTERS_MAP
        # REG_MAP_V7MBASIC, REG_MAP_V8MBASIC and any unrecognised name
        return M_CLASS_BASIC_REGISTERS_MAP

    def getNonGlobalRegisterNames( self ):
        """Return the set of every register name that any register map defines."""
        result = set( )
        for regMap in ( M_CLASS_BASIC_REGISTERS_MAP,
                        M_CLASS_EXTENDED_REGISTERS_MAP,
                        M_CLASS_FPU_REGISTERS_MAP,
                        A_CLASS_BASIC_REGISTERS_MAP,
                        A_CLASS_FPU_REGISTERS_MAP ):
            result.update( regMap.keys( ) )
        return result
