Simple Deobfuscation of Code Transformation


Recently I had to deal with an annoying piece of malware. It used compression and an obfuscation technique called code transformation. The packer is known as Win32.Krap, and trust me, the name is justified. Most of the packer code can be bypassed by setting hardware execute breakpoints on the first dword of each memory region returned by VirtualAlloc. Three allocations later, we land in the code block shown above. Besides making for pretty graphs, code transformation makes the assembly look ugly. The transformation itself is simple: after every four or five instructions, the packer inserts a JMP to the next real instruction.
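To make the transformation concrete, here is a minimal sketch of it in plain Python, outside of IDA. It is a toy model and not the packer's actual logic: every instruction is assumed to be one address unit wide, the junk is modelled as placeholder lines, and the names add_jmps, every and junk_len are made up for illustration.

```python
# Toy model of the transformation. Every instruction is treated as one
# address unit wide, which is an assumption made purely for illustration.

def add_jmps(instructions, every=4, junk_len=3):
    """Insert a JMP plus junk after every `every` instructions.

    Returns a list of (address, text) pairs. `add_jmps`, `every` and
    `junk_len` are hypothetical names, not part of the packer.
    """
    out = []
    addr = 0
    for index, text in enumerate(instructions, 1):
        out.append((addr, text))
        addr += 1
        # No JMP is needed after the final instruction.
        if index % every == 0 and index != len(instructions):
            target = addr + 1 + junk_len  # skip the JMP itself and the junk
            out.append((addr, "jmp loc_%X" % target))
            addr += 1
            for _ in range(junk_len):
                out.append((addr, "db 0CCh"))  # junk filler, never executed
                addr += 1
    return out


original = ["push ebp", "mov ebp, esp", "sub esp, 8", "xor eax, eax",
            "inc eax", "leave", "retn"]
for address, text in add_jmps(original):
    print("%04X  %s" % (address, text))
```

The real instructions are still in their original order. They are just separated by a JMP and bytes that never execute, which is what makes the listing so noisy.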

Example assembly of code transformation
Trying to view this code in IDA is very painful due to all the added JMP instructions along with the junk data. A good example can be seen above. To view the un-obfuscated code, all we have to do is view the code minus the JMPs. This sounds simple, but I have to admit it was a good learning experience. The algorithm is similar to a recursive traversal disassembler, except we aren't disassembling anything: we skip each JMP and pick up the instruction at the address the JMP goes to. A good example of the algorithm can be found here (PDF). What would the code look like without the JMPs?
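The traversal idea can be sketched without IDA at all. The snippet below is a simplified illustration, not the script from this post: the program is a hand-written dictionary mapping an address to (mnemonic, operand, next address), and follow, PROGRAM and COND_JMPS are names invented for the example. Unconditional JMPs are followed and dropped, conditional jump targets are queued to come back to, and a set of visited addresses keeps the walk from looping.

```python
# Hypothetical toy program: address -> (mnemonic, operand, next_address).
# Addresses 2-4 and 8 are junk and are deliberately absent.
PROGRAM = {
    0: ("mov", None, 1),
    1: ("jmp", 5, None),
    5: ("jz", 9, 6),
    6: ("add", None, 7),
    7: ("jmp", 10, None),
    9: ("sub", None, 10),
    10: ("ret", None, None),
}

COND_JMPS = {"jz", "jnz", "jb", "ja"}  # shortened list for the example


def follow(program, start):
    """Return the blocks reachable from `start`, with plain JMPs removed."""
    visited = set()
    pending = [start]  # addresses we still have to come back to
    blocks = []
    while pending:
        addr = pending.pop()
        block = []
        while addr in program and addr not in visited:
            visited.add(addr)
            mnem, operand, nxt = program[addr]
            if mnem == "jmp":
                addr = operand  # skip the JMP and continue at its target
                continue
            block.append((addr, mnem))
            if mnem in COND_JMPS:
                pending.append(operand)  # taken branch, visited later
            if mnem == "ret":
                break
            addr = nxt  # fall through to the next instruction
        if block:
            blocks.append(block)
    return blocks


for block in follow(PROGRAM, 0):
    print(["%X: %s" % (a, m) for a, m in block])
```

An explicit work list is used here instead of recursion only to keep the sketch short. The visited set plays the same role as visitedAddr in the script.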


Below is the code I used to create the viewer and remove the JMPs. A word of caution: use at your own risk. Odds are there are some bugs, bad logic, or unused variables. Hopefully this will be a helpful POC for others. If you have any tips or ideas, please leave a comment or shoot me an email.


# Python code to remove JMPs from obfuscated code in IDA
# created by alexander dot hanel at gmail dot com
# Note: you will need to have your cursor at the start
# of the function, or at least somewhere in its path.

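import idaapi  # COLSTR, SCOLOR_INSN, simplecustviewer_t
import idautils
import idc
from idaapi import *
from idc import *  # ScreenEA, GetFunctionAttr, NextHead, GetDisasm, GetMnem, Byte, ...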

class JMPJMP:
    def __init__(self):
        self.ea = ScreenEA()
        self.errorStatus = 'Good'
        self.funcStartAddr = GetFunctionAttr(self.ea, FUNCATTR_START)
        self.checkFunctionStart()
        self.buffer = []
        self.count = 0
        self.condJmps = ['jo', 'jno', 'jb', 'jnae', 'jc', 'jnb', 'jae', 'jnc', 'jz', \
                                'je', 'jnz', 'jne', 'jbe', 'jna', 'jnbe', 'ja', 'js', 'jns', \
                                'jp', 'jpe', 'jnp', 'jpo', 'jl', 'jnge', 'jnl', 'jge', 'jle', \
                                'jng', 'jnle', 'jg']
        self.condJmpsAddr = set([])
        self.retn = ['retn', 'ret', 'retf']
        self.callAddr = set([])
        self.call = 'call' 
        self.callByte = 0xe8
        self.jmp = 'jmp'
        self.visitedAddr = set([])
        self.target = set([])
    
    def getJmpAddress(self, addr):
        "returns the address the JMP instruction jumps to"
        return GetOperandValue(addr, 0)
        
    def checkFunctionStart(self):
        'checks if the function start address is valid'
        # compare by value; "is" tests object identity and is unreliable for integers
        if self.funcStartAddr == BADADDR:
            print("Could not find function start address")
            self.errorStatus = 'Bad!'

    def checkAddr(self,addr):
        'checks if the address is valid'
        if addr == BADADDR:
            print("Invalid address encountered")
            self.errorStatus = 'Bad!'

    def getNext(self, addr):
        "returns the next address, its disassembly, mnemonic and first byte"
        nxt = NextHead(addr)
        # read the byte at the next address, not at the current one
        return nxt, GetDisasm(nxt), GetMnem(nxt), Byte(nxt)
        
    def getCur(self, addr):
        "returns the address, disassembly, mnemonic and first byte"
        return addr, GetDisasm(addr), GetMnem(addr), Byte(addr)
        
    def formatLine(self,addr):
        'format the line to mimic IDA layout' 
        return   idaapi.COLSTR(SegName(addr) + ':' + '%08X' % addr, idaapi.SCOLOR_INSN) + '\t' + idaapi.COLSTR(GetDisasm(addr) , idaapi.SCOLOR_INSN)
        
    def printBuffer(self):
        'print the buffer that contains the instructions minus jmps'
        v = idaapi.simplecustviewer_t()
        if v.Create("JMP CleanUp Viewer"):
            for instru in self.buffer:
                v.AddLine(instru)
            v.Show()
        else:
            print("Failed to create viewer, wa waa waaaaa")
        
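    def simplify(self, addr, target=None):
        # avoid a mutable default argument, which would be shared between calls
        if target is None:
            target = []
        # skip addresses that were already visited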
        if addr in self.visitedAddr:
            return
        else:
            current_addr, current_inst, current_mnem, byte = self.getCur(addr)
            temp = current_addr
            self.buffer.append('__start: %s' % hex(temp))
            while(1):
                self.checkAddr(current_addr)
                if self.errorStatus != 'Good':
                    return    
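                # compare for equality; "in" on a string is a substring test
                # and would also match an empty mnemonic
                if current_mnem == self.jmp:
                    # uncomment if you want to see the jmp instruction in the output 
                    #self.buffer.append(self.formatLine(current_addr))
                    if current_addr in self.visitedAddr:
                        # this JMP was already followed; stop to avoid looping forever
                        break
                    jmpAddr = self.getJmpAddress(current_addr)
                    self.visitedAddr.add(current_addr)    
                    current_addr, current_inst, current_mnem, byte = self.getCur(jmpAddr)
                    continue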
                # check for conditional jmps, if so add to the target aka come back to list
                elif current_mnem in self.condJmps:
                    self.buffer.append(self.formatLine(current_addr))
                    jmpAddr = self.getJmpAddress(current_addr)
                    target.append(jmpAddr)
                # if call, we will need the call address
                # (0xE8 is the near relative call opcode; string equality, not substring)
                elif current_mnem == self.call and byte == self.callByte:
                    self.buffer.append(self.formatLine(current_addr))
                    target.append(GetOperandValue(current_addr,0))                
                else:
                    self.buffer.append(self.formatLine(current_addr))
                
                if current_mnem in self.retn or current_addr in self.visitedAddr:
                    break
                self.visitedAddr.add(current_addr)
                current_addr, current_inst, current_mnem, byte = self.getNext(current_addr)
            
            self.buffer.append('__end: %s ' % hex(temp))
            self.buffer.append('')
            for revisit in target:
                if revisit in self.visitedAddr:
                    continue
                else:
                    self.simplify(revisit, target)
                    
        return

def main():
    simp = JMPJMP()
    simp.simplify(GetFunctionAttr(ScreenEA(), FUNCATTR_START))
    simp.printBuffer()


if __name__ == "__main__":
    main()

4 comments:

  1. Hi Alex,

    First, thanks for sharing this. Second, a short question: "Do you ever sleep?!?!?" ;)

  2. Hi

    Great work.
    It is always good to see implementation material over some theoretical PDFs
    I was wondering if it is possible to get a version of the packer in order to investigate it ?
    from what I see the callgraph itself is pretty huge and removing the garbage is only the 1st stage
    to remove the whole obfuscation

    Shift

    Replies
    1. My email is in the code. Shoot me an email I can share the sample. Cheers.

  3. I don't have the sample, so I can't test it at the moment, but I think optimice (http://code.google.com/p/optimice/) will figure out this kind of obfuscation tricks and deobfuscate the code.

    It has some additional deobfuscation techniques, it is definitely worth a try:)
