2. 用自然的代码表达
step12 可变长参数(正向)
主要是解决多输入多输出问题
例如Add函数:
class Add(Function):
def forward(self, x0, x1):
y = x0 + x1
return y
def add(x0, x1):
return Add()(x0, x1)
对Function类的改造如下:
class Function:
def __call__(self, *inputs):
xs = [x.data for x in inputs]
ys = self.forward(*xs) # 使用*号解包
if not isinstance(ys, tuple): # 对非元组情况的额外处理
ys = (ys,)
outputs = [Variable(as_array(y)) for y in ys]
for output in outputs:
output.set_creator(self)
self.inputs = inputs # 保存输入的变量
self.outputs = outputs
return outputs if len(outputs) > 1 else outputs[0]
# 若输出是标量,则直接返回标量,否则以列表形式返回