# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable


class CustomGLU(nn.Module):
    r"""Custom Gated Linear Unit activation.

    Applies a modified gated linear unit :math:`a * f(b)`, where :math:`a` is the
    first half of the input, :math:`b` is the second half, and :math:`f` is a
    provided activation function (e.g. sigmoid, swish).

    Args:
        activation (nn.Module): The custom activation to apply in the Gated Linear Unit.
        dim (int): The dimension on which to split the input. Default: -1.

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions.
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M = N / 2`.

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    def __init__(self, activation: nn.Module, dim: int = -1):
        super().__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor) -> Tensor:
        # The split dimension must have even size so it can be halved into (a, b).
        assert x.shape[self.dim] % 2 == 0, "size on `dim` must be even (M = N / 2)"
        a, b = torch.chunk(x, 2, dim=self.dim)
        return a * self.activation(b)
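

# Illustration of the shape contract above (an added sketch, not part of the
# original interface): splitting an input of size N = 6 on the last dimension
# yields two halves of size 3, so the output size is M = N / 2 = 3, e.g.
#   CustomGLU(nn.Tanh())(torch.randn(4, 6)).shape == torch.Size([4, 3])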


class SwiGLU(CustomGLU):
    """SiLU Gated Linear Unit activation.

    Applies the SiLU Gated Linear Unit :math:`a * SiLU(b)`, where :math:`a` is the
    first half of the input and :math:`b` is the second half.

    Args:
        dim (int): The dimension on which to split the input. Default: -1.
    """
    def __init__(self, dim: int = -1):
        super().__init__(nn.SiLU(), dim)
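

# Equivalence note (added for clarity): SwiGLU(dim)(x) computes the same result
# as the functional form
#   a, b = torch.chunk(x, 2, dim); a * torch.nn.functional.silu(b)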


class GeGLU(CustomGLU):
    """GELU Gated Linear Unit activation.

    Applies the GELU Gated Linear Unit :math:`a * GELU(b)`, where :math:`a` is the
    first half of the input and :math:`b` is the second half.

    Args:
        dim (int): The dimension on which to split the input. Default: -1.
    """
    def __init__(self, dim: int = -1):
        super().__init__(nn.GELU(), dim)


class ReGLU(CustomGLU):
    """ReLU Gated Linear Unit activation.

    Applies the ReLU Gated Linear Unit :math:`a * ReLU(b)`, where :math:`a` is the
    first half of the input and :math:`b` is the second half.

    Args:
        dim (int): The dimension on which to split the input. Default: -1.
    """
    def __init__(self, dim: int = -1):
        super().__init__(nn.ReLU(), dim)


def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Helper function to map an activation string to its activation module.

    If the supplied activation is not a recognized string, it is passed back unchanged.

    Args:
        activation (Union[str, Callable[[Tensor], Tensor]]): Activation to resolve.
    """
    if isinstance(activation, str):
        if activation == "reglu":
            return ReGLU()
        elif activation == "geglu":
            return GeGLU()
        elif activation == "swiglu":
            return SwiGLU()
    return activation
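

# Minimal usage sketch (added example, not part of the original module), showing
# the halving shape contract and the string-to-module resolution above.
if __name__ == "__main__":
    x = torch.randn(4, 6)  # last-dimension size N = 6 (must be even)
    for name in ("swiglu", "geglu", "reglu"):
        act = get_activation_fn(name)
        y = act(x)
        # Each GLU variant halves the split dimension: (4, 6) -> (4, 3).
        assert y.shape == (4, 3), (name, y.shape)
    # Unrecognized values are passed through unchanged.
    assert get_activation_fn("relu") == "relu"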