Chen Shuo's Practical Network Programming - TTCP Lecture代码注释

发布时间 2023-11-15 19:59:06作者: Khun8

下面是C语言版本的TTCP,主要注释的是void receive(const Options& opt);函数,负责在服务器接收客户端发送的数据:

// muduo/examples/ace/ttcp/ttcp_blocking.cc

#include ...

// 接受新的TCP连接
static int acceptOrDie(uint16_t port)
{
  ...
}

// 完整的写N个字节
static int write_n(int sockfd, const void* buf, int length)
{
  int written = 0;
  while (written < length)
  {
    ssize_t nw = ::write(sockfd, static_cast<const char*>(buf) + written, length - written);
    if (nw > 0)
    {
      written += static_cast<int>(nw);
    }
    else if (nw == 0)
    {
      break;  // EOF
    }
    else if (errno != EINTR)
    {
      perror("write");
      break;
    }
  }
  return written;
}

// 完整的读N个字节
static int read_n(int sockfd, void* buf, int length)
{
  int nread = 0;
  while (nread < length)
  {
    ssize_t nr = ::read(sockfd, static_cast<char*>(buf) + nread, length - nread);
    if (nr > 0)
    {
      nread += static_cast<int>(nr);
    }
    else if (nr == 0)
    {
      break;  // EOF
    }
    else if (errno != EINTR)
    {
      perror("read");
      break;
    }
  }
  return nread;
}

// 发送(客户端)
void transmit(const Options& opt)
{
  ...
}

// 接收端(服务端)
void receive(const Options& opt)
{
  int sockfd = acceptOrDie(opt.port);

  // 准备读客户端发送过来的SessionMessage
  struct SessionMessage sessionMessage = { 0, 0 };
  // 检测是否读了完整的8个字节
  if (read_n(sockfd, &sessionMessage, sizeof(sessionMessage)) != sizeof(sessionMessage))
  {
    perror("read SessionMessage");
    exit(1);
  }

  // 网络字节序转成本机字节序(大端小端问题)
  sessionMessage.number = ntohl(sessionMessage.number); // Functions to convert between host and network byte order.
  sessionMessage.length = ntohl(sessionMessage.length); // Functions to convert between host and network byte order.
  // 打印出预计接收多少条数据,预计每条数据有多大
  printf("receive number = %d\nreceive length = %d\n",
         sessionMessage.number, sessionMessage.length);
  const int total_len = static_cast<int>(sizeof(int32_t) + sessionMessage.length); 
  // !!!漏洞!!!如果length很大,则malloc会分配可能超出可用范围的内存,造成拒绝响应攻击(但是因为TTCP是在内网使用的工具,所以这里没有严格约束)
  // 准备缓冲区来接收数据
  PayloadMessage* payload = static_cast<PayloadMessage*>(::malloc(total_len));
  // struct PayloadMessage
  // {
  //   int32_t length;
  //   char data[0]; // 不定长数组,放在结构体最后一个元素,运行时决定大小
  // };
  assert(payload);

  for (int i = 0; i < sessionMessage.number; ++i)
  {
    payload->length = 0;
    if (read_n(sockfd, &payload->length, sizeof(payload->length)) != sizeof(payload->length))
    {
      perror("read length");
      exit(1);
    }
    payload->length = ntohl(payload->length);
    assert(payload->length == sessionMessage.length);
    if (read_n(sockfd, payload->data, payload->length) != payload->length)
    {
      perror("read payload data");
      exit(1);
    }
    // 构造一个响应字节,告诉收到了多长的数据
    int32_t ack = htonl(payload->length);
    if (write_n(sockfd, &ack, sizeof(ack)) != sizeof(ack))
    {
      perror("write ack");
      exit(1);
    }
  }
  ::free(payload);
  ::close(sockfd);
}

对于客户端的发送数据的逻辑主要对C++版本的TTCP中,void transmit(const Options& opt);进行注释:

#include ...

struct Options
{
  uint16_t port;
  int length;
  int number;
  bool transmit, receive, nodelay;
  std::string host;
  Options()
    : port(0), length(0), number(0),
      transmit(false), receive(false), nodelay(false)
  {
  }
};

struct SessionMessage
{
  int32_t number;
  int32_t length;
} __attribute__ ((__packed__));

struct PayloadMessage
{
  int32_t length;
  char data[0];
};

double now()
{
  struct timeval tv = { 0, 0 };
  gettimeofday(&tv, NULL);
  return tv.tv_sec + tv.tv_usec / 1000000.0;
}

// FIXME: rewrite with getopt(3).
bool parseCommandLine(int argc, char* argv[], Options* opt)
{
  ...
}

// 发送(客户端)
void transmit(const Options& opt)
{
  InetAddress addr(opt.port);
  // 将主机名解析成IP地址
  if (!InetAddress::resolve(opt.host.c_str(), &addr))
  {
    printf("Unable to resolve %s\n", opt.host.c_str());
    return;
  }

  printf("connecting to %s\n", addr.toIpPort().c_str());
  // 根据IP地址连接
  // TcpStreamPtr是一个unique_ptr,指向TcpStream,因此不用手动关闭连接,超出作用域会自动关闭
  TcpStreamPtr stream(TcpStream::connect(addr)); // 按值返回,使用C++11的移动语义实现的,不用担心拷贝对象会造成的问题
  if (!stream)
  {
    printf("Unable to connect %s\n", addr.toIpPort().c_str());
    perror("");
    return;
  }


  // Nagle算法是一种由John Nagle在1984年提出的算法,旨在减少小数据包发送时的网络拥塞问题
  // 该算法通过将小的数据块缓冲并组合成更大的数据块来提高网络的效率
  // 具体来说,当应用程序发送的数据量较小时,Nagle算法会将数据缓存起来,直到以下两个条件之一满足时再发送数据:
  //   1. 接收方确认了之前发送的数据(确认ACK)
  //   2. 待发送的数据量超过了一个称为Nagle算法定时器的设定阈值
  // Nagle算法的目的是减少小数据包的发送次数,从而提高网络的效率。然而,它也会引入一定的延迟,因为数据发送方需要等待确认或达到定时器阈值才能发送数据
  // 与Nagle算法相对的是TCP的无延迟模式(TCP_NODELAY)
  // 无延迟模式的作用是禁用Nagle算法中的延迟确认机制,从而减少数据的传输延迟
  // 在无延迟模式下,TCP将立即发送数据,而不需要等待确认或达到定时器阈值

  if (opt.nodelay)
  {
    stream->setTcpNoDelay(true); // 设置TCP NODELAY,禁用掉NAGLE(在opt.nodelay为真时启用TCP的无延迟模式)
  }
  printf("connected\n");
  double start = now();
  struct SessionMessage sessionMessage = { 0, 0 };
  sessionMessage.number = htonl(opt.number);
  sessionMessage.length = htonl(opt.length);
  if (stream->sendAll(&sessionMessage, sizeof(sessionMessage)) != sizeof(sessionMessage))
  {
    /////////////////////////////////////////////////// short write /////////////////////////////////////////////////////
    // 在网络编程中,当使用TCP进行数据发送时,应用程序通常会将一定数量的数据写入发送缓冲区,并期望所有数据都能被完整发送。
    // 然而,由于网络条件、接收方的接收窗口大小等因素,有时可能发生短写(short write)的情况。
    // short write意味着实际写入的字节数少于应用程序请求的字节数。这可能发生在以下情况下:
    //  1. 发送缓冲区空间不足:发送缓冲区可能已满,无法容纳所有的数据,因此只有部分数据被写入。
    //  2. 网络拥塞:网络拥塞可能导致数据包丢失或延迟,从而造成short write。
    //  3. 接收方窗口大小限制:接收方的接收窗口大小限制了发送方能够发送的数据量,因此可能发生short write。
    // 在处理short write时,应用程序通常需要检查实际写入的字节数,并根据需要采取适当的处理措施,例如重新发送未完全发送的数据或调整发送策略。

    // 意味着连接可能断开了
    perror("write SessionMessage");
    return;
  }

  const int total_len = sizeof(int32_t) + opt.length;
  PayloadMessage* payload = static_cast<PayloadMessage*>(::malloc(total_len));
  // c++11 unique_ptr 
  // 超出作用域自动free
  std::unique_ptr<PayloadMessage, void (*)(void*)> freeIt(payload, ::free);
  assert(payload);
  payload->length = htonl(opt.length);
  for (int i = 0; i < opt.length; ++i)
  {
    payload->data[i] = "0123456789ABCDEF"[i % 16];
  }

  double total_mb = 1.0 * opt.length * opt.number / 1024 / 1024;
  printf("%.3f MiB in total\n", total_mb);

  for (int i = 0; i < opt.number; ++i)
  {
    int nw = stream->sendAll(payload, total_len);
    assert(nw == total_len);

    int ack = 0;
    int nr = stream->receiveAll(&ack, sizeof(ack));
    assert(nr == sizeof(ack)); // assert 简化的错误处理,不推荐使用
    ack = ntohl(ack);
    assert(ack == opt.length);
  }

  double elapsed = now() - start;
  // 计算带宽
  printf("%.3f seconds\n%.3f MiB/s\n", elapsed, total_mb / elapsed);
}

void receive(const Options& opt)
{
  ...
}

int main(int argc, char* argv[])
{
  Options options;
  if (parseCommandLine(argc, argv, &options))
  {
    if (options.transmit)
    {
      transmit(options);
    }
    else if (options.receive)
    {
      receive(options);
    }
    else
    {
      assert(0);
    }
  }
}