将clang 打印的AST转成svg

发布时间：2023-05-31 13:59:34　作者：浅笑一场

将clang 打印的AST转成svg

start.sh

先将 clang 打印的 AST 结果写入 input.txt，

再用 toDot.py 将 input.txt 转成 output.dot 文件，

最后使用 Graphviz 的 dot 命令由 output.dot 生成 output.svg。

如果 C 文件包含了头文件，AST 可能过于庞大，导致生成的 SVG 无法打开。此时可以手动删除 input.txt 中的部分 AST 结果（应按整棵子树删除、保持缩进层级连续，否则其后的节点会被丢弃），然后不带参数再次运行 ./start.sh，以免 input.txt 被重新生成而覆盖修改。

#!/bin/bash
# Convert clang's AST dump of a C file into an SVG graph.
#
# $1 (optional): C source file. When given, its AST is dumped into input.txt.
#                When omitted, the existing input.txt is reused (e.g. after
#                trimming it by hand).
# Example 1: ./start.sh test.c
# Example 2: ./start.sh

# Stop at the first failing step so dot never renders a stale output.dot.
set -e

if [ $# -eq 1 ]; then
	# Strip ANSI colour escape sequences in case clang emits them.
	clang "$1" -Xclang -ast-dump -fsyntax-only -std=c99 | sed "s,\x1B\[[0-9;]*[a-zA-Z],,g" > input.txt
elif [ ! -f input.txt ]; then
	echo "input.txt not found; run: $0 file.c" >&2
	exit 1
fi
# Convert the AST dump into a Graphviz DOT file (toDot.py requires Python 3).
python3 ./toDot.py input.txt output.dot
# Render the SVG.
dot -Nfontsize=10 -Tsvg output.dot -o output.svg

toDot.py

#!/usr/bin/python3
import re
import shlex
import sys

# Command-line arguments of this script:
#   argv[1]: input file name (the clang AST dump)
#   argv[2]: output file name (the Graphviz DOT file)
#
# Everything print()ed while the graph is generated ends up in the output
# file, because stdout is redirected to it in the __main__ section.

# Own-address string -> list of every AST_Node parsed with that address.
addrMap = dict()
# Id handed to the next node created by process(); bumped once per node.
gIndex = 0

class AST_Node:
    """One node of the clang AST dump, prepared for output as a DOT record.

    Attributes:
        index: integer id of the node; also used as its DOT node name.
        spelling: node kind as printed by clang, e.g. ``FunctionDecl``.
        address: the node's own address token as printed by clang.
        loc: source range, followed by the ``line:``/``col:`` token that
            comes right after it (if any), with Graphviz record-label
            metacharacters backslash-escaped.
        other: the remaining tokens of the dump line (a private copy).
        children: child nodes; empty here, filled in by take().
    """

    # Characters that are structural inside a Graphviz record label, plus
    # the backslash and the double quote delimiting the label string.
    _RECORD_ESCAPES = str.maketrans({c: "\\" + c for c in '\\{}|<>"'})

    def __init__(self, index, spelling, address, loc, other):
        self.index = index
        self.spelling = spelling
        self.address = address
        # Copy so that removing the location token below does not mutate
        # the caller's list.
        self.other = list(other)
        self.children = []

        # A "line:.."/"col:.." token right after the range is the node's own
        # location; fold it into loc so it is shown in the loc field.
        if self.other and ("line:" in self.other[0] or "col:" in self.other[0]):
            loc = loc + " " + self.other.pop(0)
        # Escape after merging so the moved token is escaped as well.
        self.loc = loc.translate(self._RECORD_ESCAPES)


def process(line):
    """Parse one line of ``clang -ast-dump`` output into ``(depth, node)``.

    clang's tree-drawing prefix ("| ", "|-", "`-", spaces) is two characters
    per nesting level, so the column of the first letter divided by two is
    the node's depth.  The rest of the line is tokenized with shlex:
    token 0 is the node kind, token 1 its address.  If token 2 contains a
    "<...>" span, that span becomes the location and the token is consumed.
    All remaining non-blank tokens go into ``other``.

    Returns:
        (depth, node): depth is an int; node is a new AST_Node whose index
        is taken from the global ``gIndex`` counter, which is incremented.

    Raises:
        ValueError: if the line contains no letter or no address token, or
            if shlex cannot tokenize it (e.g. unbalanced quotes).
    """
    global gIndex
    # Width of the tree-drawing prefix = column of the first letter.
    start = next((i for i, ch in enumerate(line) if ch.isalpha()), None)
    if start is None:
        raise ValueError("no AST node on line: %r" % line)
    tokens = shlex.split(line[start:])
    if len(tokens) < 2:
        raise ValueError("AST line has no address: %r" % line)
    spelling, address, rest = tokens[0], tokens[1], tokens[2:]

    loc = ""
    match = re.search("<.*?>", rest[0]) if rest else None
    if match:
        loc = match.group(0)
        rest = rest[1:]

    other = [tok.strip() for tok in rest if tok.strip()]
    node = AST_Node(gIndex, spelling, address, loc, other)
    gIndex += 1
    # Integer depth: clang indents two characters per tree level.
    return start // 2, node


# Position in allList of the most recently consumed entry; advanced by take().
index = 0


def take():
    """Build the AST tree from the flat global ``allList`` of (depth, node).

    Starts at ``allList[index]`` as the root.  It then walks forward, attaching
    each entry as a child of the nearest still-open ancestor whose depth is
    exactly one less.  An entry that fits no open ancestor ends the build.
    That entry and everything after it are left out of the tree, which
    happens when lines of the dump were deleted inconsistently.

    An explicit stack replaces the former recursion so that very deep ASTs
    (e.g. long chains of binary operators) cannot exceed Python's recursion
    limit.  The resulting tree is the same as the recursive version's.

    Returns:
        The root node, with ``children`` populated.

    Side effects:
        Advances the global ``index`` to the last consumed entry.
    """
    global index
    root_depth, root = allList[index]
    # Open ancestors, innermost last, as (depth, node) pairs.
    open_nodes = [(root_depth, root)]
    while index + 1 < len(allList):
        depth, node = allList[index + 1]
        # Close every ancestor that cannot take this entry as a direct child.
        while open_nodes and depth != open_nodes[-1][0] + 1:
            open_nodes.pop()
        if not open_nodes:
            break
        index += 1
        open_nodes[-1][1].children.append(node)
        open_nodes.append((depth, node))
    return root


def search(node: "AST_Node"):
    """Print the DOT statements for ``node`` and its whole subtree to stdout.

    Nodes are visited in pre-order: a node first, then its children in
    order.  For each node this prints:
      * a record node ``N [label = "{N:Kind | <addr>addr:.. | <0>loc:.. |
        <other>..}"];``, where the loc and other fields are present only when
        non-empty;
      * an edge to each child;
      * an edge to every node whose address appears as a token in ``other``
        (e.g. the declaration referenced by a DeclRefExpr).
    Edge endpoints list, comma-separated, the ids of all nodes sharing the
    address found in the global ``addrMap``.  NOTE(review): this relies on
    Graphviz's parser accepting ``a,b -> c`` as a node list.

    Tokens in ``other`` are escaped for Graphviz record labels, so operators
    such as ``|`` and string literals containing ``"`` or braces no longer
    corrupt the DOT output.  ``loc`` is printed as-is because AST_Node
    already escapes it.

    An explicit stack replaces the former recursion so deep ASTs cannot
    exceed Python's recursion limit.  The output order is unchanged.
    """
    # Record-label metacharacters, plus backslash and the label's quote.
    escapes = str.maketrans({c: "\\" + c for c in '\\{}|<>"'})

    def processAddr(nodes):
        if nodes is None:
            return "None"
        return ",".join(str(x.index) for x in nodes)

    pending = [node]
    while pending:
        current = pending.pop()
        print(" " + str(current.index) + " [label = \"{" + str(
            current.index) + ":" + current.spelling + " | <addr>addr:" + current.address, end=" ")
        if current.loc is not None and len(current.loc) != 0:
            print(" | <0>loc:" + current.loc, end="")
        if current.other:
            print(" | <other>" + " ".join(current.other).translate(escapes), end="")
        print("}\"];")
        for child in current.children:
            print(" " + processAddr(addrMap[current.address]) + " -> "
                  + processAddr(addrMap[child.address]), end=";\n")
        # Cross-reference edges to nodes whose address is mentioned in other.
        for token in current.other:
            if re.search("0x[0-9a-zA-Z]*", token) and addrMap.get(token):
                print(" " + processAddr(addrMap[current.address]) + "-> "
                      + processAddr(addrMap[token]), end=";\n")
        # Push children reversed so they are popped, and printed, in order.
        pending.extend(reversed(current.children))


if __name__ == '__main__':
    # Script entry point: toDot.py <clang AST dump> <output .dot file>
    import contextlib

    if len(sys.argv) < 3:
        sys.exit("usage: %s INPUT_AST_TXT OUTPUT_DOT" % sys.argv[0])
    inputFile = sys.argv[1]
    outputFile = sys.argv[2]

    # Matches a source range such as <line:3:1, col:5>; it is wrapped in
    # double quotes so shlex keeps it as a single token.
    locPattern = re.compile("<((.*?)(:[0-9]*)+)>")

    # Flat list of (depth, AST_Node) in dump order; take() reads this global.
    allList = []
    with open(inputFile, "r", encoding="utf-8", errors="replace") as f:
        for lineno, line in enumerate(f, 1):
            # Skip blank lines (e.g. left over after hand-trimming) and the
            # placeholder clang prints for absent children.
            if not line.strip() or "<<<NULL>>>" in line:
                continue
            line = line.replace("<<invalid sloc>> <invalid sloc>", "")
            # Quote every range with its OWN text; reusing the first match's
            # text for all matches would corrupt lines holding several ranges.
            line = locPattern.sub(lambda m: "\"<" + m.group(1) + ">\"", line)
            try:
                allList.append(process(line))
            except (ValueError, IndexError) as err:
                sys.exit("%s:%d: cannot parse AST line: %s" % (inputFile, lineno, err))

    if not allList:
        sys.exit("no AST nodes found in " + inputFile)

    # Index nodes by their own address so edges can be drawn to them.
    for _, node in allList:
        addrMap.setdefault(node.address, []).append(node)

    # Build the tree, then write the DOT graph (UTF-8, Graphviz's default
    # input encoding).  search() prints, so stdout is redirected only for
    # the duration of the write and restored afterwards.
    root = take()
    with open(outputFile, "w", encoding="utf-8") as out, contextlib.redirect_stdout(out):
        print("digraph G{\n node [shape = record]")
        search(root)
        print("}")


示例

对如下代码 test.c 运行 ./start.sh test.c；
test.c
然后删除 input.txt 中头文件产生的大部分内容，只保留 printf 相关的部分；
再不带参数运行一次 ./start.sh。

示例C代码

#include <stdio.h>
int main(int argc,char*argv[])
{
	printf("hello world");
	return 0;
}

示例生成的svg结果

image