手写 Django ORM 反向迁移 MySQL

发布时间: 2023-06-12 16:20:52  作者: 最冷不过冬夜

import keyword
import os
import re

import pymysql

#### settings (fill these in before running) ####
# Connection parameters for the MySQL database to introspect.
db = dict(
    NAME='',      # database name
    USER='',      # login user
    PASSWORD='',  # login password
    HOST='',      # server host
    PORT='',      # server port
)
# Tables to generate models for; when left empty, every table in the
# database is used.
table_name_list = []
# Directory that receives the generated models.py; when left empty, the
# current working directory is used.
address = ""


#### settings end ####

# 对结果集美化方法
def dictfetchall(cursor):
    # 获取游标描述
    desc = cursor.description
    return [
        dict(zip([col[0] for col in desc], row))
        for row in cursor.fetchall()
    ]


# Map the base column type reported by MySQL's DESC (the text before any
# "(...)") to the Django model field class used for it.
# "pk" is not a MySQL type: it is a pseudo-key the generator looks up for
# primary-key columns.
# NOTE(review): "nvarchar" and "datetime2" are not MySQL type names; they look
# like leftovers from a SQL Server variant of this table.
modelsType = {
    "nvarchar": "CharField",
    "varchar": "CharField",
    "char": "CharField",
    "int": "IntegerField",
    "decimal": "DecimalField",
    "datetime": "DateTimeField",
    "real": "FloatField",
    "varbinary": "CharField",
    "text": "TextField",
    "date": "DateField",
    "datetime2": "DateTimeField",
    "float": "FloatField",
    "bit": "BooleanField",
    "smallint": "IntegerField",
    "pk": "AutoField",
    "bigint": "BigIntegerField",
    "longtext": "TextField",
    "double": "FloatField",  # floating-point type; IntegerField dropped fractions
    # Other common MySQL base types, so they map to a specific field class.
    "tinyint": "IntegerField",
    "mediumint": "IntegerField",
    "year": "IntegerField",
    "timestamp": "DateTimeField",
    "time": "TimeField",
    "tinytext": "TextField",
    "mediumtext": "TextField",
    "tinyblob": "BinaryField",
    "blob": "BinaryField",
    "mediumblob": "BinaryField",
    "longblob": "BinaryField",
}

def _text(value):
    """Return *value* as ``str``; decode it first if the driver returned bytes.

    Defensive only: lets the parsing below work whether DESC values arrive as
    ``str`` or ``bytes``.  Non-bytes values (including ``None``) pass through.
    """
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8")
    return value


def _split_column_type(column_type):
    """Split a DESC ``Type`` string into ``(base_type, params)``.

    ``"decimal(10,2)"`` -> ``("decimal", "10,2")``;
    ``"int(11) unsigned"`` -> ``("int", "11")``;
    ``"text"`` -> ``("text", "")``.
    The base type is lower-cased so it can be looked up in ``modelsType``.
    """
    match = re.match(r"\s*(\w+)\s*(?:\((.*)\))?", column_type)
    if match is None:
        return column_type.strip().lower(), ""
    return match.group(1).lower(), match.group(2) or ""


def _field_name(column):
    """Turn a column name into a valid Python attribute name for the model."""
    name = re.sub(r"\W", "_", column)
    if not name or name[0].isdigit():
        name = "field_" + name
    if keyword.iskeyword(name):  # e.g. a column called "class" or "from"
        name += "_field"
    return name


def _class_name(table):
    """Build a CamelCase class name from a table name (``user_info`` -> ``UserInfo``)."""
    name = "".join(part.capitalize() for part in re.split(r"[\W_]+", table))
    if not name or name[0].isdigit():
        name = "Table" + name
    return name


def _field_definition(column):
    """Render one DESC row as a ``\\n    name = models.XxxField(...)`` line."""
    column_name = _text(column["Field"])
    raw_type = _text(column["Type"])
    base_type, params = _split_column_type(raw_type)
    field_type = modelsType.get(base_type)
    note = ""
    if field_type is None:
        # Unmapped type: fall back to TextField and flag it for manual review
        # instead of aborting the whole run with a KeyError.
        field_type = "TextField"
        note = "  # TODO: unmapped MySQL type %r" % raw_type

    is_pk = _text(column["Key"]) == "PRI"
    extra = (_text(column.get("Extra")) or "").lower()
    if is_pk and "auto_increment" in extra:
        # Only auto-increment keys become AutoField; natural keys (e.g. a
        # varchar PK) keep their real type and just get primary_key=True.
        if field_type == "BigIntegerField":
            field_type = "BigAutoField"
        else:
            field_type = modelsType["pk"]

    options = ["db_column=%r" % column_name]
    if field_type == "DecimalField" and params:
        digits, _, places = params.partition(",")
        options.append("max_digits=%s" % digits.strip())
        options.append("decimal_places=%s" % (places.strip() or "0"))
    elif field_type == "CharField" and params.strip().isdigit() and int(params) > 0:
        # max_length only applies to CharField; int(11), datetime(6),
        # float(10,2) etc. carry display widths/precision, not lengths.
        options.append("max_length=%s" % params.strip())
    if is_pk:
        options.append("primary_key=True")
    if (_text(column["Null"]) or "").upper() == "YES":  # MySQL reports "YES"/"NO"
        options.append("blank=True, null=True")
    return "\n    %s = models.%s(%s)%s" % (
        _field_name(column_name), field_type, ", ".join(options), note)


# Main script: connect, list the tables (or use table_name_list), DESC each
# one and write an unmanaged Django model per table into models.py.
conn = pymysql.connect(
    host=db['HOST'],
    port=int(db['PORT'] or 3306),
    user=db['USER'],
    password=db['PASSWORD'],
    database=db['NAME'],
    charset='utf8mb4',  # keep non-ASCII table/column names intact
)
# pymysql.connect() raises on failure, so no separate "connected?" check is
# needed here.
try:
    cur = conn.cursor()
    if not table_name_list:
        cur.execute("SHOW TABLES")
        # The single result column is named Tables_in_<database>, so read it
        # by position rather than hard-coding one database's name.
        table_name_list = [_text(row[0]) for row in cur.fetchall()]
    print(len(table_name_list))
    address = "models.py" if not address else os.path.join(address, 'models.py')
    # Drop duplicate table names while preserving their order.
    tc = list(dict.fromkeys(table_name_list))
    with open(address, mode="w", encoding="utf-8") as f:
        f.write('from django.db import models\n')
        for table_name in tc:
            # Back-tick quoting so reserved words and unusual names still work.
            cur.execute("DESC `%s`" % table_name.replace("`", "``"))
            fields = "".join(_field_definition(column) for column in dictfetchall(cur))
            f.write("\n\nclass %s(models.Model):%s" % (_class_name(table_name), fields))
            f.write("\n\n    class Meta:\n        managed = False\n        db_table = %r\n"
                    % table_name)
finally:
    conn.close()
print("models生成完成\n生成表为%s" % (tc,))