python并行之flask-socketio

发布时间 2023-12-02 23:09:16作者: jasonzhangxianrong

1、服务器端

from flask import *
from flask_socketio import *
from flask_socketio import SocketIO
from nasbench_lib.nasbench_201 import NASBench201
import random
import subprocess
class Server:
    """Flask-SocketIO coordinator for NAS worker clients.

    Accepts socket connections from workers, tracks which client session
    ids are ready, and dispatches a resource check to a random subset of
    ready clients each round.
    """

    def __init__(self, gpu):
        """Build the Flask app, the SocketIO wrapper and server state.

        Args:
            gpu: row index into the local ``nvidia-smi`` output used by
                 ``get_slurm_allocated_gpus`` (i.e. which GPU to query).
        """
        self.app = Flask(__name__)
        # Very long ping timeouts keep slow training clients connected.
        # 10**32 is an exact integer; the original int(1e32) round-trips
        # through a float and yields 100000000000000005366162204393472.
        self.socketio = SocketIO(self.app, ping_timeout=3600000,
                                 ping_interval=3600000,
                                 max_http_buffer_size=10**32)
        self.gpu = gpu
        self.MIN_NUM_WORKERS = 1
        self.current_round = -1  # -1 means training has not started yet
        self.NUM_CLIENTS_CONTACTED_PER_ROUND = 1
        self.ready_client_sids = set()
        self.nas = NASBench201()
        # Register the SocketIO event handlers on this instance.
        self.register_handles()

    def check_client_resource(self):
        """Ask a random subset of ready clients to report their resources."""
        self.client_resource = {}
        # Clamp the sample size: random.sample raises ValueError when asked
        # for more elements than the population contains, which would happen
        # if NUM_CLIENTS_CONTACTED_PER_ROUND ever exceeded the ready count.
        sample_size = min(self.NUM_CLIENTS_CONTACTED_PER_ROUND,
                          len(self.ready_client_sids))
        client_sids_selected = random.sample(list(self.ready_client_sids),
                                             sample_size)
        for rid in client_sids_selected:
            emit('check_client_resource', {
                'round_number': self.current_round,
            }, room=rid)

    def register_handles(self):
        """Attach all SocketIO event handlers (closures over ``self``)."""

        @self.socketio.on('connect')
        def handle_connect():
            print(request.sid, "connected")

        @self.socketio.on('reconnect')
        def handle_reconnect():
            print(request.sid, "reconnected")

        @self.socketio.on('disconnect')
        def handle_disconnect():
            print(request.sid, "disconnected")
            # Forget the client so it is not sampled in future rounds.
            if request.sid in self.ready_client_sids:
                self.ready_client_sids.remove(request.sid)

        @self.socketio.on('client_wake_up')
        def handle_wake_up():
            print(f"服务器端被客户端{request.sid}唤醒.")
            emit('init')

        @self.socketio.on('client_ready')
        def handle_client_ready():
            print(f"服务器收到客户端{request.sid}准备完毕。开始check资源。")
            self.ready_client_sids.add(request.sid)
            if len(self.ready_client_sids) >= self.MIN_NUM_WORKERS:
                self.check_client_resource()
            else:
                print("没有足够的客户端连接....")

    def handle_connect(self):
        # NOTE(review): this method is never registered with SocketIO (the
        # decorated closure of the same name in register_handles is) and
        # appears to be dead code; kept for interface compatibility.
        print("Client connected")

    def handle_sample(self):
        """Return 10 random architectures from the NAS-Bench-201 space."""
        return self.nas.generate_random_for_multiview(10)

    def get_slurm_allocated_gpus(self):
        """Return free memory (MiB) of GPU ``self.gpu`` via nvidia-smi.

        Parses one CSV row per GPU of the form ``total, used`` and returns
        ``total - used`` for the configured GPU index.
        """
        result = subprocess.run([
            'nvidia-smi',
            '--query-gpu=memory.total,memory.used',
            '--format=csv,nounits,noheader'
        ], stdout=subprocess.PIPE)
        output = result.stdout.decode('utf-8').strip().split('\n')
        # int() tolerates the leading space nvidia-smi puts after the comma.
        gpu_0_memory = output[self.gpu].split(',')
        total_memory = int(gpu_0_memory[0])
        used_memory = int(gpu_0_memory[1])
        free_memory = total_memory - used_memory
        return free_memory

    def run(self, host='0.0.0.0', port=5000):
        """Start the SocketIO development server (blocking)."""
        self.socketio.run(self.app, host=host, port=port)

def _serve():
    """Entry point: run a coordinator server reading GPU index 0."""
    Server(0).run()


if __name__ == '__main__':
    _serve()

 

2、客户端

import socketio
class Worker:
    """SocketIO client that talks to the NAS coordinator server."""

    def __init__(self, server_host, server_port):
        """Prepare the client and register event handlers.

        BUGFIX: the original constructor also called ``connect_to_server()``,
        which blocks forever inside ``sio.wait()`` — the constructor never
        returned, and callers that later invoked ``connect_to_server()``
        themselves (as GPUManager does) would attempt a second connect on an
        already-connected client. Connecting is now the caller's job.
        """
        self.sio = socketio.Client()
        self.server_url = f'http://{server_host}:{server_port}'
        self.register_handles()

    def on_init(self):
        """Handle the server's 'init' event: load the model, then report ready."""
        print('客户端进行初始化.')
        # Load the local model here.
        print("客户端本地模型加载完毕.")
        # Ready to be dispatched for training.
        self.sio.emit('client_ready')

    def register_handles(self):
        """Attach all client-side SocketIO event handlers."""

        @self.sio.event
        def connect():
            print('客户端请求连接...')
            self.sio.emit("client_wake_up")

        @self.sio.event
        def disconnect():
            print('Disconnected')

        @self.sio.event
        def reconnect():
            print('Reconnected')

        @self.sio.on('check_client_resource')
        def on_check_client_resource(*args):
            # Acknowledge the server's resource probe.
            self.sio.emit('check_client_resource_done')

        self.sio.on('init', self.on_init)

    def connect_to_server(self):
        """Connect to the server and block until disconnected."""
        print("Connecting to server...")
        self.sio.connect(self.server_url)
        self.sio.wait()

    def request_sample_architecture(self):
        """Ask the server for sampled architectures; print the reply."""
        self.sio.emit('sample', callback=self.print_response)

    def request_evolve_architectures(self, architectures):
        """Ask the server to evolve the given architectures; print the reply."""
        self.sio.emit('evolve', architectures, callback=self.print_response)

    @staticmethod
    def print_response(data):
        print("Response from server:", data)

 

3、管理端

from client import Worker
import torch.multiprocessing as mp

class GPUManager(object):
    """Launches worker client processes that connect to the coordinator."""

    def __init__(self):
        self.p_count = 10        # number of parallel client processes
        self.available_gpus = 6  # NOTE(review): unused here; kept for callers
        self.port_list = [
            78901,
            78902,
        ]

    @staticmethod
    def _client_entry(host, port):
        """Process target: construct the Worker inside the child process.

        BUGFIX: the original built each Worker in the parent and handed the
        bound ``connect_to_server`` method to ``mp.Process``.  That requires
        pickling a live ``socketio.Client`` (not picklable under the spawn
        start method), and the Worker constructor itself connected and
        blocked, so the parent never got past the first construction.
        Building the Worker in the child avoids both problems.
        """
        Worker(host, port).connect_to_server()

    def run_client(self):
        """Spawn ``p_count`` client processes and reap them within 48 hours."""
        TIMEOUT = 48 * 3600  # per-process join budget, in seconds
        proc = [
            mp.Process(target=self._client_entry, args=("127.0.0.1", 5000))
            for _ in range(self.p_count)
        ]
        for p in proc:
            p.start()
        for p in proc:
            p.join(timeout=TIMEOUT)
        # Force-kill anything that outlived the timeout before closing.
        for p in proc:
            if p.is_alive():
                p.terminate()
                p.join()
        for p in proc:
            p.close()

    def run(self):
        """Run a single client in the current process (blocking)."""
        worker = Worker("127.0.0.1", 5000)
        worker.connect_to_server()


def _main():
    """Entry point: run one blocking client in this process."""
    GPUManager().run()


if __name__ == '__main__':
    _main()