数位dp通用模板 -- 记忆化搜索

发布时间 2024-01-07 10:50:33作者: 深渊之巅

 

class Solution:
    def countSpecialNumbers(self, n: int) -> int:
        s = str(n)

        '''
            返回从i开始填数字,i前面填的数字集合是mask,能构造出的特殊整数的个数
            is_limit 表示前面填的数字是否是n对应位上的,及下一个填的数字是否有限制,如果为false表示至多为9,否则至多为s[i]
            is_num 表示前面是否填了数字(是否跳过),若为true,当前位可以从0开始.
       is_limit 用来处理前一个位置填的数字对后一个的限制,is_num用来处理前导0
''' @cache def f(i: int, mask: int, is_limit: bool, is_num: bool) -> int: if i == len(s): return int(is_num) res = 0 if not is_num: res = f(i + 1, mask, False, False) low = 0 if is_num else 1 up = int(s[i]) if is_limit else 9 for d in range(low, up + 1): if (mask >> d & 1) == 0: res += f(i + 1, mask | (1 << d), is_limit and d == up, True) return res return f(0, 0, True, False)

 

 

 

class Solution:
    def atMostNGivenDigitSet(self, digits: List[str], n: int) -> int:
        s = str(n)

        @cache
        def f(i: int, is_limit: bool, is_num: bool) -> int:
            if i == len(s):
                return int(is_num)
            
            res = 0
            if not is_num:
                res = f(i + 1, False, False)
            up = s[i] if is_limit else '9'
            for d in digits:
                if d > up:
                    break
                res += f(i + 1, is_limit and d == up, True)
            return res
        return f(0, True, False)

 

 

 

class Solution:
    def countDigitOne(self, n: int) -> int:
        s = str(n)

        @cache
        def f(i: int, cnt: int, is_limit: bool) -> int:
            if i == len(s):
                return cnt
            low = 0
            res = 0
            up = int(s[i]) if is_limit else 9
            for d in range(low, up + 1):
                res += f(i + 1, cnt + (d == 1), is_limit and d == up)
            
            return res
        
        return f(0, 0, True)

 

 

 

class Solution:
    def numberOfPowerfulInt(self, start: int, finish: int, limit: int, s: str) -> int:
        n = len(s)
        
        @cache
        def f(i: int, is_limit: bool, t: str):
            if len(t) < n:
                return 0
            
            if len(t) - i == n:
                if is_limit:
                    return int(s <= t[i:])
                else:
                    return 1
            
            res = 0
            low = 0
            up = int(t[i]) if is_limit else 9
            up = min(up, limit)
            
            for d in range(low, up + 1):
                res += f(i + 1, is_limit and d == int(t[i]), t)
            
            return res
        
        return f(0, True, str(finish)) - f(0, True, str(start - 1))