Flask源码分析系列之一

发布时间: 2023-12-03 21:44:22  作者: 你好aloha

Flask源码分析我打算写一个系列,这篇文章先讲讲Flask下开启服务的过程。

众所周知,Flask开启服务有两种方式:在v0.11之前只能通过Flask类提供的run()方法来开启服务;从v0.11开始,Flask官方增加了基于Click的`flask`命令行工具,可以通过命令方式来开启服务,例如:

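flask --app main:app run --host=0.0.0.0 --port=5000 --reload

(Flask 2.2之前没有`--app`选项,需要通过环境变量指定应用:`FLASK_APP=main:app flask run --host=0.0.0.0 --port=5000 --reload`)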

本文先分析传统的run()方式,下一篇再分析通过命令方式开启服务的具体实现。

class Flask:
    def run(self, host=None, port=None, debug=None, load_dotenv=True, **options):
        """Start a local development server for this application.

        Simplified excerpt of ``flask.Flask.run``: fill in the default
        address, then hand ``self`` (the WSGI callable) to werkzeug's
        ``run_simple``, which blocks while the server runs.

        :param host: interface to bind; a falsy value falls back to
            ``"127.0.0.1"``.
        :param port: port to bind; ``None`` falls back to ``5000``.
        :param debug: accepted for API compatibility; unused in this excerpt.
        :param load_dotenv: accepted for API compatibility; unused in this
            excerpt.
        :param options: extra keyword arguments forwarded to ``run_simple``
            (e.g. ``threaded``, ``processes``, ``request_handler``).
        """
        # Flask's documented defaults. Forwarding None would leave
        # run_simple without a usable address.
        if not host:
            host = "127.0.0.1"
        port = 5000 if port is None else int(port)

        # Serve each request in its own thread unless the caller opts out.
        options.setdefault("threaded", True)

        from werkzeug.serving import run_simple

        try:
            run_simple(host, port, self, **options)
        finally:
            # Clear the first-request flag once the server stops, whether it
            # exited normally or raised.
            self._got_first_request = False

werkzeug组件提供的run_simple()方法

def run_simple(
    hostname,
    port,
    application,
    threaded=False,
    processes=1,
    request_handler=None
):
    """Create a development WSGI server for *application* and serve forever.

    ``make_server`` picks the server flavour (threaded, forking or plain)
    from *threaded* and *processes*. If the environment variable
    ``WERKZEUG_SERVER_FD`` holds an integer, it is forwarded as ``fd``
    (presumably the descriptor of an already-open listening socket).
    """

    def _inherited_fd():
        # A missing variable and a non-integer value both mean "no fd".
        try:
            return int(os.environ["WERKZEUG_SERVER_FD"])
        except (LookupError, ValueError):
            return None

    def _serve():
        server = make_server(
            hostname,
            port,
            application,
            threaded,
            processes,
            request_handler,
            fd=_inherited_fd(),
        )
        # Blocks here. Each accepted connection instantiates the request
        # handler (WSGIRequestHandler by default), whose handle() reaches
        # run_wsgi(), which calls the application (Flask.__call__).
        # The werkzeug servers' serve_forever() also turns Ctrl-C into a
        # clean stop.
        server.serve_forever()

    _serve()

下面的make_server()方法根据threaded和processes两个参数来创建web服务器。
werkzeug提供了三种类型的Web服务器:

  • 多线程Web服务器
  • 多进程Web服务器
  • 单进程单线程Web服务器

1.单进程单线程Web服务器(BaseWSGIServer)

LISTEN_QUEUE = 128


class BaseWSGIServer(HTTPServer, object):
    """Single-process, single-threaded WSGI server (werkzeug excerpt).

    Wraps the standard library ``HTTPServer`` and remembers the WSGI
    application that every request handler dispatches to.
    """

    multithread = False
    multiprocess = False
    # Backlog handed to socket.listen() by TCPServer.server_activate().
    request_queue_size = LISTEN_QUEUE

    def __init__(
        self,
        host,
        port,
        app,
        handler=None,
        fd=None,
    ):
        """Bind the server and store the application.

        :param host: host name or address to bind.
        :param port: port to bind (converted with ``int``).
        :param app: WSGI application called for every request.
        :param handler: request handler class; defaults to
            ``WSGIRequestHandler``.
        :param fd: optional file descriptor of an already-open socket to
            serve from instead of the freshly bound one. ``run_simple`` and
            ``ForkingWSGIServer`` pass this argument.
        """
        if handler is None:
            handler = WSGIRequestHandler

        self.address_family = select_address_family(host, port)

        inherited_sock = None
        if fd is not None:
            import socket

            inherited_sock = socket.fromfd(
                fd, self.address_family, socket.SOCK_STREAM
            )
            # The socket bound below is only temporary: use any free port.
            port = 0

        server_address = get_sockaddr(host, int(port), self.address_family)

        # Only registers the handler class. It is instantiated per request
        # in BaseServer.finish_request().
        HTTPServer.__init__(self, server_address, handler)

        if inherited_sock is not None:
            # Swap the temporary socket for the inherited one.
            self.socket.close()
            self.socket = inherited_sock
            self.server_address = self.socket.getsockname()

        self.app = app
        self.shutdown_signal = False
        self.host = host
        # Read the port from the socket actually in use, so it is correct
        # for port 0 and for an inherited fd.
        self.port = self.socket.getsockname()[1]

    def serve_forever(self):
        """Serve until stopped; Ctrl-C counts as a normal stop.

        The listening socket is always closed on the way out.
        """
        self.shutdown_signal = False
        try:
            # serve_forever -> _handle_request_noblock -> process_request
            # -> finish_request, which instantiates the request handler.
            HTTPServer.serve_forever(self)
        except KeyboardInterrupt:
            pass
        finally:
            self.server_close()

    def get_request(self):
        """Accept one connection; return ``(connection, client_address)``.

        TCPServer (HTTPServer's base class) already provides an equivalent
        get_request(); werkzeug simply overrides it explicitly.
        """
        con, info = self.socket.accept()
        return con, info

2.多线程Web服务器

# The standard library mix-in must be bound before it is used as a base
# class below; otherwise the class statement raises NameError.
ThreadingMixIn = socketserver.ThreadingMixIn


class ThreadedWSGIServer(ThreadingMixIn, BaseWSGIServer):
    """WSGI server that handles each request in a new thread.

    ThreadingMixIn comes first in the bases, so its process_request()
    overrides BaseServer's and starts one thread per accepted connection.
    """

    multithread = True
    # Daemon worker threads are not tracked by ThreadingMixIn, so
    # server_close() does not wait for them.
    daemon_threads = True

3.多进程web服务器

can_fork = hasattr(os, "fork")

if not can_fork:
    # Windows has no os.fork(). Provide an empty placeholder so the class
    # statement below still works; instantiating the server raises instead.
    class ForkingMixIn:
        pass
else:
    # Platforms with os.fork() (e.g. Linux): use the standard library mix-in.
    ForkingMixIn = socketserver.ForkingMixIn


class ForkingWSGIServer(ForkingMixIn, BaseWSGIServer):
    """WSGI server that forks a child process per request.

    :param processes: stored as ``max_children``, the number of tracked
        children at which ForkingMixIn blocks to reap finished ones.
    :raises ValueError: when the platform has no ``os.fork``.
    """

    multiprocess = True

    def __init__(
        self,
        host,
        port,
        app,
        processes=40,
        handler=None,
        fd=None,
    ):
        if can_fork:
            BaseWSGIServer.__init__(self, host, port, app, handler, fd)
            self.max_children = processes
            return
        raise ValueError("Your platform does not support forking.")

以上三个类都是werkzeug组件提供的,它们基于Python标准库中的底层模块(socketserver、http.server)封装出了符合WSGI协议规范的Web服务器。
下面我们往下看一看底层模块这块“砖头”给我们提供了什么。

先看一下vars()到底是个啥?

class X:
    """Tiny demo class whose only instance attribute is ``o``."""

    def __init__(self):
        self.o = {'a': 1, 'b': 'xx'}

# Called with an object, vars() returns that object's __dict__.
ret = vars(X())
print(ret)  # {'o': {'a': 1, 'b': 'xx'}}
print(X().__dict__)  # same content: {'o': {'a': 1, 'b': 'xx'}}
# Called without an argument, vars() behaves like locals(); at module level
# that is the module namespace, so this prints every global name.
print(vars())

vars()相当于locals(),vars(对象)相当于对象.__dict__

下面我们来看底层模块的实现:

# 主要有三块内容:
# 1.请求循环:serve_forever()基于select轮询持续处理请求,直到调用shutdown()(或按Ctrl-C中断)才退出;handle_request()只处理单个请求
# 2.连接管理(TCPServer/UDPServer)
# 3.支持多线程及多进程的ThreadingMixIn和ForkingMixIn
class _Threads(list):
    def append(self, thread):
        self.reap()
        if thread.daemon:
            return
        super().append(thread)

    def pop_all(self):  # 都取出所有线程,执行之后self上不再有线程
        self[:], result = [], self[:]
        return result

    def join(self):  # 所有子线程挂掉后主线程才退出
        for thread in self.pop_all():
            thread.join()

    def reap(self):  # 收集所有存活的线程
        self[:] = (thread for thread in self if thread.is_alive())


class _NoThreads:
    """Stand-in for ``_Threads`` when worker threads must not be tracked.

    append() forgets the thread and join() returns immediately.
    """

    def append(self, thread):
        pass

    def join(self):
        pass


class ThreadingMixIn:
    """Mix-in class that handles each request in a new thread."""

    # Whether worker threads are started as daemon threads.
    daemon_threads = False
    # If true, server_close() waits for all non-daemon worker threads.
    block_on_close = True
    # Class-level default that tracks nothing. A real _Threads list is
    # created per instance in process_request() only when block_on_close is
    # true. A shared _Threads here would collect threads from every server
    # and make server_close() join them even with block_on_close False.
    _threads = _NoThreads()

    def process_request_thread(self, request, client_address):
        """Thread target: handle one request, report errors, then close it."""
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Start a worker thread running process_request_thread()."""
        if self.block_on_close:
            # Per-instance list; setdefault keeps the one made on first use.
            vars(self).setdefault('_threads', _Threads())
        t = threading.Thread(target=self.process_request_thread,
                             args=(request, client_address))
        t.daemon = self.daemon_threads
        self._threads.append(t)
        t.start()

    def server_close(self):
        """Close the server, then wait for tracked worker threads (if any)."""
        super().server_close()
        self._threads.join()


if hasattr(os, "fork"):  # only where os.fork exists (not on Windows)
    class ForkingMixIn:
        """Mix-in class that handles each request in a forked child process.

        Class attributes:
        - timeout: default for BaseServer.handle_request()'s wait limit.
        - active_children: set of tracked child pids, or None before the
          first fork.
        - max_children: collect_children() blocks while at least this many
          children are tracked.
        - block_on_close: if true, server_close() waits for every child.
        """
        timeout = 300
        active_children = None
        max_children = 40
        block_on_close = True

        def collect_children(self, *, blocking=False):
            """Reap exited children and prune ``active_children``.

            First, while too many children are tracked, block in
            ``waitpid(-1, 0)`` to reap any child. Then check each tracked
            pid, without blocking unless *blocking* is true.
            """
            if self.active_children is None:
                return

            while len(self.active_children) >= self.max_children:  # too many children
                try:
                    pid, _ = os.waitpid(-1, 0)  # block until ANY one child exits, and reap it
                    self.active_children.discard(pid)
                except ChildProcessError:  # no child processes left to wait for: forget them all
                    self.active_children.clear()
                except OSError:
                    break

            # Now reap all defunct children.
            for pid in self.active_children.copy():  # iterate a copy: the set is mutated below
                try:
                    flags = 0 if blocking else os.WNOHANG  # WNOHANG: return (0, 0) at once if the child still runs
                    pid, _ = os.waitpid(pid, flags)
                    # pid is 0 for a still-running child; discard(0) is then a no-op
                    self.active_children.discard(pid)
                except ChildProcessError:
                    # the child no longer exists (e.g. already reaped): stop tracking it
                    self.active_children.discard(pid)
                except OSError:
                    pass

        def handle_timeout(self):
            """Called by handle_request() when no request arrived in time: reap children."""
            self.collect_children()

        def service_actions(self):
            """Called once per serve_forever() loop iteration: reap children."""
            self.collect_children()

        def process_request(self, request, client_address):
            """Fork a child that handles the request; the parent keeps serving."""
            pid = os.fork()  # create the child process
            # Both processes continue from here: fork() returns the child's
            # pid in the parent and 0 in the child.
            if pid:  # parent: record the child and close its own copy of the connection
                # Parent process
                if self.active_children is None:
                    self.active_children = set()
                self.active_children.add(pid)
                self.close_request(request)
                return
            else:  # child: handle the request, then leave through os._exit()
                # Child process.
                # os._exit() ensures the child never returns into the
                # server loop; the exit status is 0 only on success.
                status = 1
                try:
                    self.finish_request(request, client_address)
                    status = 0
                except Exception:
                    self.handle_error(request, client_address)
                finally:
                    try:
                        self.shutdown_request(request)
                    finally:
                        os._exit(status)

        def server_close(self):
            """Close the server, then reap children (waiting for all if block_on_close)."""
            super().server_close()
            self.collect_children(blocking=self.block_on_close)



class HTTPServer(socketserver.TCPServer):
    """TCP server that additionally records its host name and bound port."""

    # Truthy, so TCPServer.server_bind() sets SO_REUSEADDR before binding.
    allow_reuse_address = 1

    def server_bind(self):
        """Bind like TCPServer, then store ``server_name``/``server_port``."""
        socketserver.TCPServer.server_bind(self)
        bound_host, bound_port = self.server_address[:2]
        self.server_name = socket.getfqdn(bound_host)
        self.server_port = bound_port


class ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
    """HTTPServer that serves each request in its own daemon thread."""

    daemon_threads = True



class TCPServer(BaseServer):
    """Connection management on top of BaseServer.

    Owns one stream socket that is bound, put into listening mode and used
    to accept client connections.
    """

    address_family = socket.AF_INET
    socket_type = socket.SOCK_STREAM
    # Backlog passed to listen().
    request_queue_size = 5
    # When true, SO_REUSEADDR is set before bind().
    allow_reuse_address = False

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
        """Store address and handler class, create the socket, and optionally
        bind and listen.

        If binding or listening fails, the socket is closed before the
        exception propagates.
        """
        BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.socket = socket.socket(self.address_family, self.socket_type)
        if not bind_and_activate:
            return
        try:
            self.server_bind()
            self.server_activate()
        except BaseException:
            # Do not leak the socket; re-raise the original error unchanged.
            self.server_close()
            raise

    def server_bind(self):
        """Bind the socket and record the address actually bound."""
        if self.allow_reuse_address:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(self.server_address)
        # Re-read the address, e.g. to learn the port chosen for port 0.
        self.server_address = self.socket.getsockname()

    def server_activate(self):
        """Start listening for incoming connections."""
        self.socket.listen(self.request_queue_size)

    def server_close(self):
        """Close the listening socket."""
        self.socket.close()

    def fileno(self):
        """Return the socket's descriptor (lets the server be registered with a selector)."""
        return self.socket.fileno()

    def get_request(self):
        """Accept a connection; return ``(connection, client_address)``."""
        return self.socket.accept()

    def shutdown_request(self, request):
        """Shut down the write side of the connection, then close it."""
        try:
            request.shutdown(socket.SHUT_WR)
        except OSError:
            # The connection may already be gone; closing it is what matters.
            pass
        self.close_request(request)

    def close_request(self, request):
        """Close a single client connection."""
        request.close()




class BaseServer:
    """Transport-independent server core from the standard library socketserver.

    Provides the request loop (serve_forever() / handle_request()) and the
    per-request pipeline get_request() -> verify_request() ->
    process_request() -> finish_request() -> shutdown_request(), plus
    context-manager support. Subclasses such as TCPServer supply the
    socket handling.
    """

    # Upper bound in seconds for handle_request(); None means no limit.
    timeout = None

    def __init__(self, server_address, RequestHandlerClass):
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        # Set whenever serve_forever() exits; shutdown() waits on it.
        self.__is_shut_down = threading.Event()
        # Polled by serve_forever(); set to True by shutdown().
        self.__shutdown_request = False

    def server_activate(self):
        """Hook for subclasses (TCPServer starts listening here)."""
        pass

    def serve_forever(self, poll_interval=0.5):
        """Handle requests until shutdown() is called.

        Waits for the server to become readable, re-checking the shutdown
        flag every *poll_interval* seconds, and calls service_actions()
        after every loop iteration.
        """
        self.__is_shut_down.clear()
        try:
            with _ServerSelector() as selector:
                # I/O multiplexing: wake up when the server (via its
                # fileno()) is readable, i.e. a connection can be accepted.
                selector.register(self, selectors.EVENT_READ)

                while not self.__shutdown_request:
                    ready = selector.select(poll_interval)
                    # shutdown() may have been called while we were waiting.
                    if self.__shutdown_request:
                        break
                    if ready:
                        self._handle_request_noblock()

                    self.service_actions()
        finally:
            # Reset the flag so serve_forever() can be started again, then
            # release any thread blocked in shutdown().
            self.__shutdown_request = False
            self.__is_shut_down.set()

    def shutdown(self):
        """Ask serve_forever() to stop, and wait until it has.

        Call this from another thread while serve_forever() is running. The
        event it waits on is only set when serve_forever() exits, so calling
        it from the serving thread (or before serving) blocks forever.
        """
        self.__shutdown_request = True
        self.__is_shut_down.wait()

    def service_actions(self):
        """Hook run once per serve_forever() loop iteration (no-op here).

        ForkingMixIn, for example, reaps finished child processes here.
        """
        pass

    def handle_request(self):
        """Handle a single request, waiting at most the effective timeout.

        The effective timeout is the socket's timeout or ``self.timeout``,
        or the smaller of the two when both are set. If nothing arrives in
        time, handle_timeout() is called and its result returned.
        """
        timeout = self.socket.gettimeout()
        if timeout is None:
            timeout = self.timeout
        elif self.timeout is not None:
            timeout = min(timeout, self.timeout)
        if timeout is not None:
            deadline = time() + timeout

        with _ServerSelector() as selector:
            selector.register(self, selectors.EVENT_READ)

            # Loop because select() may return early with nothing ready;
            # the remaining time is recomputed on each pass.
            while True:
                ready = selector.select(timeout)
                if ready:
                    return self._handle_request_noblock()
                else:
                    if timeout is not None:
                        timeout = deadline - time()
                        if timeout < 0:
                            return self.handle_timeout()

    def _handle_request_noblock(self):
        """Accept one pending connection and run it through the pipeline.

        Called by serve_forever()/handle_request() once the selector has
        reported the server as readable.
        """
        try:
            request, client_address = self.get_request()
        except OSError:
            # Accepting failed; skip this round.
            return
        if self.verify_request(request, client_address):
            try:
                self.process_request(request, client_address)
            except Exception:
                # Report the error and close the connection; keep serving.
                self.handle_error(request, client_address)
                self.shutdown_request(request)
            except BaseException:
                # KeyboardInterrupt, SystemExit, ...: close, then propagate.
                self.shutdown_request(request)
                raise
        else:
            self.shutdown_request(request)

    def handle_timeout(self):
        """Called when handle_request() times out (no-op here)."""
        pass

    def verify_request(self, request, client_address):
        """Return True to accept the request; override to filter clients."""
        return True

    def process_request(self, request, client_address):
        """Handle the request synchronously in this thread, then close it.

        ThreadingMixIn and ForkingMixIn override this to go concurrent.
        """
        self.finish_request(request, client_address)
        self.shutdown_request(request)

    def server_close(self):
        """Release server resources (no-op here; TCPServer closes its socket)."""
        pass

    def finish_request(self, request, client_address):
        """Instantiate the handler class; its constructor processes the request."""
        self.RequestHandlerClass(request, client_address, self)

    def shutdown_request(self, request):
        """Terminate an individual request (here: just close it)."""
        self.close_request(request)

    def close_request(self, request):
        """Clean up an individual request (no-op here)."""
        pass

    def handle_error(self, request, client_address):
        """Report an exception raised while processing a request.

        Called from an ``except`` block by _handle_request_noblock() and by
        the mix-ins. It prints the active exception to stderr, and the
        server keeps running. Override to log differently.
        """
        import sys
        import traceback

        print('-' * 40, file=sys.stderr)
        print('Exception occurred during processing of request from',
              client_address, file=sys.stderr)
        traceback.print_exc()
        print('-' * 40, file=sys.stderr)

    # ----------------------------- context manager -----------------------------
    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.server_close()

有了以上这些功能,服务器已经可以启动了,但还不能真正处理请求——处理请求要用到请求处理器WSGIRequestHandler提供的功能。

# BaseRequestHandler comes from the standard library's socketserver module,
# not from werkzeug.
class BaseRequestHandler:
    """Base class for request handlers: one instance per request.

    The constructor stores the request, the client address and the server,
    then runs the hooks setup() -> handle() -> finish(). finish() runs even
    if handle() raises. WSGIRequestHandler follows the same sequence. The
    hooks are no-ops here and are meant to be overridden.
    """

    def __init__(self, request, client_address, server):
        self.request = request
        self.client_address = client_address
        self.server = server
        self.setup()
        try:
            self.handle()
        finally:
            self.finish()

    def setup(self):
        """Prepare for handling the request (no-op by default)."""
        pass

    def handle(self):
        """Service the request (no-op by default)."""
        pass

    def finish(self):
        """Clean up after the request (no-op by default)."""
        pass


class WSGIRequestHandler(BaseHTTPRequestHandler, object):
    """Request handler implementing the server side of WSGI (PEP 3333).

    It does not subclass BaseRequestHandler directly, but it follows the
    same setup() -> handle() -> finish() sequence, which runs when the
    server instantiates it for a connection. On top of
    BaseHTTPRequestHandler it:

    1. builds the WSGI ``environ`` dict (make_environ());
    2. provides ``start_response`` and the ``write`` callable that emits
       the status line, headers and body.

    Both happen in run_wsgi(), reached via handle() -> handle_one_request().
    """

    def make_environ(self):
        """Build the WSGI environ for the current request.

        The predefined WSGI/CGI keys are elided in this excerpt. Request
        headers are added upper-cased with dashes turned into underscores,
        and prefixed with ``HTTP_`` except for Content-Type and
        Content-Length. Repeated ``HTTP_`` headers are joined with a comma.
        """
        environ = {
            # predefined WSGI environment keys (omitted in this excerpt)
        }

        for key, value in self.get_header_items():
            key = key.upper().replace("-", "_")
            value = value.replace("\r\n", "")
            if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                key = "HTTP_" + key
                if key in environ:
                    value = "{},{}".format(environ[key], value)
            environ[key] = value

        return environ

    def run_wsgi(self):
        """Call the server's WSGI application for the current request.

        The application (``self.server.app``, e.g. ``Flask.__call__``; any
        WSGI app works) gets ``environ`` and ``start_response``. Every chunk
        it yields goes to ``write``. Status line and headers are sent lazily,
        right before the first chunk, or before an empty body.
        """
        self.environ = environ = self.make_environ()
        headers_set = []   # [status, headers] from the latest start_response()
        headers_sent = []  # copy of headers_set once they have been written

        def write(data):
            """Send the status line and headers once, then *data*."""
            assert headers_set, "write() before start_response"
            if not headers_sent:
                # 1. Status line.
                status, response_headers = headers_sent[:] = headers_set
                try:
                    code, msg = status.split(None, 1)
                except ValueError:
                    # Status without a reason phrase, e.g. "200".
                    code, msg = status, ""
                code = int(code)
                self.send_response(code, msg)

                # 2. Headers.
                header_keys = set()
                for key, value in response_headers:
                    self.send_header(key, value)
                    header_keys.add(key.lower())
                # Without Content-Length the end of the body is signalled by
                # closing the connection. HEAD, 1xx, 204 and 304 responses
                # carry no body, so they are exempt.
                if not (
                    "content-length" in header_keys
                    or environ["REQUEST_METHOD"] == "HEAD"
                    or code < 200
                    or code in (204, 304)
                ):
                    self.close_connection = True
                    self.send_header("Connection", "close")
                if "server" not in header_keys:
                    self.send_header("Server", self.version_string())
                if "date" not in header_keys:
                    self.send_header("Date", self.date_time_string())
                self.end_headers()

            # 3. Body.
            assert isinstance(data, bytes), "applications must write bytes"
            if data:
                self.wfile.write(data)
            self.wfile.flush()

        def start_response(status, response_headers, exc_info=None):
            """WSGI start_response: record the status and headers to send.

            Per PEP 3333, calling it again is only allowed with *exc_info*.
            If the headers have already been sent, the response can no longer
            be replaced, so the original error is re-raised.
            """
            if exc_info:
                try:
                    if headers_sent:
                        raise exc_info[1].with_traceback(exc_info[2])
                finally:
                    # Drop the traceback reference (PEP 3333 recommendation).
                    exc_info = None
            elif headers_set:
                raise AssertionError("Headers already set")
            headers_set[:] = [status, response_headers]
            return write

        def execute(app):
            application_iter = app(environ, start_response)
            try:
                for data in application_iter:
                    write(data)
                # Empty body: still send the status line and headers.
                if not headers_sent:
                    write(b"")
            finally:
                # PEP 3333: close() must be called if the iterable has it.
                if hasattr(application_iter, "close"):
                    application_iter.close()

        execute(self.server.app)

    def handle(self):
        """Serve the connection, treating dropped connections as non-fatal."""
        try:
            # BaseHTTPRequestHandler.handle() keeps calling
            # handle_one_request() (which reaches run_wsgi()) until
            # close_connection is set.
            BaseHTTPRequestHandler.handle(self)
        except (_ConnectionError, socket.timeout) as e:
            self.connection_dropped(e)
        except Exception as e:
            # When the server uses SSL, SSL errors are swallowed; anything
            # else propagates.
            if self.server.ssl_context is None or not is_ssl_error(e):
                raise

    def connection_dropped(self, error, environ=None):
        """Hook called by handle() when the client connection drops.

        No-op by default; override to log or clean up.
        """

    def handle_one_request(self):
        """Read one request line and dispatch the request to run_wsgi()."""
        self.raw_requestline = self.rfile.readline()
        if not self.raw_requestline:
            # Nothing read (EOF): stop serving this connection.
            self.close_connection = 1
        elif self.parse_request():
            # Per the http.server docs, a False result means an error
            # response has already been sent.
            return self.run_wsgi()

    def send_response(self, code, message=None):
        """Write the HTTP status line for *code*.

        When *message* is None, the standard reason phrase from
        ``self.responses`` is used (empty for unknown codes) instead of the
        literal text "None". Nothing is written for HTTP/0.9 requests.
        """
        if message is None:
            message = self.responses.get(code, ("",))[0]
        if self.request_version != "HTTP/0.9":
            hdr = "%s %d %s\r\n" % (self.protocol_version, code, message)
            self.wfile.write(hdr.encode("ascii"))

    def get_header_items(self):
        """Return the request headers as ``(name, value)`` pairs."""
        return self.headers.items()

由上述分析可知,启动流程如下:
1.默认开启多线程web服务器ThreadedWSGIServer
2.请求交给WSGIRequestHandler处理,处理器实例化后自动调用handle(),对于有效请求,调用run_wsgi()处理
3.在run_wsgi()中调用Flask.__call__()进而执行wsgi_app(),后面就是框架对请求的具体处理,如果不用Flask框架只借助werkzeug组件重写一套处理逻辑就要从这里开始!