tortoise-orm 使用雪花算法生成主键ID

发布时间：2023-09-16 17:51:05 | 作者：时光煮酒丶
import time
from tortoise import Tortoise, fields, run_async
from tortoise.models import Model

from typing import Any


class Snowflake:
    """Generate 64-bit Snowflake IDs for a single machine ID.

    Bit layout of a generated ID, most significant bits first::

        [ ms since TWEPOCH | machine_id (10 bits) | sequence (12 bits) ]

    The sequence counter and the last-used timestamp are instance state, so
    every caller that uses the same machine ID must share ONE instance.
    Creating a fresh instance per ID resets the sequence to 0 and produces
    duplicate IDs whenever two IDs are requested within the same millisecond.
    """

    # 1288834974657 ms after the Unix epoch == 2010-11-04 01:42:54.657 UTC
    # (the epoch used by Twitter's reference Snowflake implementation).
    TWEPOCH = 1288834974657
    MACHINE_ID_BITS = 10
    SEQUENCE_BITS = 12
    MAX_MACHINE_ID = (1 << MACHINE_ID_BITS) - 1  # 1023
    SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1  # 4095
    TIMESTAMP_SHIFT = MACHINE_ID_BITS + SEQUENCE_BITS  # 22

    def __init__(self, machine_id: int):
        """
        Create a Snowflake ID generator.

        :param machine_id: machine/worker ID, an integer in [0, 1023]
        :raises ValueError: if machine_id does not fit in 10 bits; a larger
            value would overlap the timestamp bits and corrupt every ID
        """
        if not 0 <= machine_id <= self.MAX_MACHINE_ID:
            raise ValueError(
                f"machine_id must be in [0, {self.MAX_MACHINE_ID}], got {machine_id!r}"
            )
        self.machine_id: int = machine_id
        self.sequence: int = 0
        self.last_timestamp: int = -1

    @staticmethod
    def _wait_next_millis(last_timestamp: int) -> int:
        """Busy-wait until the millisecond clock is past ``last_timestamp``; return it."""
        timestamp = int(time.time() * 1000)
        while timestamp <= last_timestamp:
            timestamp = int(time.time() * 1000)
        return timestamp

    @property
    def generate_id(self) -> int:
        """
        Return a new, strictly increasing ID.

        This is a property, so every attribute access advances the generator
        state (sequence / last timestamp).

        :raises RuntimeError: if the wall clock moved backwards since the last ID
        """
        timestamp = int(time.time() * 1000)
        if timestamp < self.last_timestamp:
            raise RuntimeError("Clock moved backwards")
        if timestamp == self.last_timestamp:
            self.sequence = (self.sequence + 1) & self.SEQUENCE_MASK
            if self.sequence == 0:
                # All 4096 sequence values of this millisecond are used up;
                # move to the next millisecond and restart the sequence at 0.
                timestamp = self._wait_next_millis(self.last_timestamp)
        else:
            self.sequence = 0
        self.last_timestamp = timestamp
        return (
            ((timestamp - self.TWEPOCH) << self.TIMESTAMP_SHIFT)
            | (self.machine_id << self.SEQUENCE_BITS)
            | self.sequence
        )


class SnowflakeField(fields.BigIntField):
    """BIGINT UNSIGNED integer field intended to hold Snowflake IDs.

    When used as a primary key, ``generated`` defaults to True. Pass
    ``generated=False`` to supply IDs from application code instead.
    """

    SQL_TYPE = "BIGINT UNSIGNED"
    allows_generated = True

    def __init__(self, pk: bool = False, **kwargs: Any) -> None:
        if pk:
            # Default to a DB-generated key unless the caller says otherwise,
            # and normalise whatever was passed to a real bool.
            kwargs["generated"] = bool(kwargs.setdefault("generated", True))
        super().__init__(pk=pk, **kwargs)

    @property
    def constraints(self) -> dict:
        """Value bounds: at least 1 for generated/reference columns, else 0; at most 2**63 - 1."""
        minimum = 0
        if self.generated or self.reference:
            minimum = 1
        return {"ge": minimum, "le": 9223372036854775807}

    class _db_mysql:
        # Column DDL used on MySQL when the field is generated by the database.
        GENERATED_SQL = "BIGINT UNSIGNED NOT NULL PRIMARY KEY AUTO_INCREMENT"


class SnowflakeIDGenerator:
    """Async helper that hands out Snowflake IDs for a model class.

    Every instance draws from ONE shared ``Snowflake`` generator, so the
    sequence counter and last timestamp persist between calls. Building a new
    ``Snowflake`` per call would reset the sequence to 0 each time, making all
    IDs generated within the same millisecond identical.
    """

    # Shared generator state for this process, using machine ID 0. Separate
    # processes writing to the same table must use distinct machine IDs,
    # otherwise their (timestamp, sequence) pairs can still collide.
    _snowflake: Snowflake = Snowflake(0)

    def __init__(self, model):
        # Kept for API compatibility; the model does not influence the ID.
        self.model = model

    @staticmethod
    async def generate_id():
        """Return the next Snowflake ID (an int) from the shared generator."""
        # Snowflake.generate_id is a property: reading it advances the state.
        # It contains no await, so coroutines on one event loop cannot
        # interleave inside it. It is not lock-protected against threads.
        return SnowflakeIDGenerator._snowflake.generate_id


class BaseModel(Model):
    """Abstract model whose primary key is a Snowflake ID assigned in ``create()``."""

    id = SnowflakeField(pk=True, generated=False)

    class Meta:
        abstract = True

    @classmethod
    async def create(cls, **kwargs):
        """Create and save a row, always filling ``id`` with a fresh Snowflake ID.

        Any ``id`` passed by the caller is overwritten.
        """
        generator = SnowflakeIDGenerator(cls)
        new_id = await generator.generate_id()
        kwargs["id"] = new_id
        return await super().create(**kwargs)


class Event(BaseModel):
    """Demo model for the insert stress test; ``id`` is inherited from BaseModel."""

    # Event name, stored as a CharField of at most 255 characters.
    name = fields.CharField(max_length=255)



async def run():
    """Stress test: init Tortoise, create the schema, then insert Event rows forever.

    Rows are named ``Test_1``, ``Test_2``, ... and the loop never terminates.
    """
    # NOTE(review): connection credentials are hard-coded in the URL.
    await Tortoise.init(
        db_url="mysql://root:123456@192.168.1.28:3306/test?charset=utf8mb4",
        modules={"models": ["__main__"]},
    )
    await Tortoise.generate_schemas()
    index = 0
    while True:
        index += 1
        await Event.create(name=f"Test_{index}")


if __name__ == "__main__":
    # Script entry point: drive the async stress test via Tortoise's runner.
    main_coroutine = run()
    run_async(main_coroutine)

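在 `while True` 死循环写入测试中，如果每次生成 ID 都新建一个 `Snowflake` 实例（例如 `Snowflake(0)`），就会出现重复 ID。原因是序列号和上次时间戳无法在调用之间保留，同一毫秒内生成的 ID 序列号都是 0，因而彼此重复。解决办法是在进程内复用同一个 `Snowflake` 实例。(≧﹏ ≦)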