ViT cannot be serialized and deserialized in Keras when classification is set to "True"


Post by Anonymous »

I built ViT as follows:

Code:

from typing import Optional, Sequence, Union

import keras


@keras.utils.register_keras_serializable(package="ViT", name="ViT")
class ViT(keras.Model):
    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        proj_type: str = "conv",
        pos_embed_type: str = "learnable",
        classification: bool = False,
        num_classes: int = 2,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        post_activation: Optional[str] = "Tanh",
        qkv_bias: bool = False,
        save_attn: bool = False,
        **kwargs,
    ) -> None:
        """
        Args:
            in_channels (int): dimension of input channels.
            img_size (Union[Sequence[int], int]): dimension of input image.
            patch_size (Union[Sequence[int], int]): dimension of patch size.
            hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
            mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
            num_layers (int, optional): number of transformer blocks. Defaults to 12.
            num_heads (int, optional): number of attention heads. Defaults to 12.
            proj_type (str, optional): patch embedding layer type. Defaults to "conv".
            pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
            classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
            num_classes (int, optional): number of classes if classification is used. Defaults to 2.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
            post_activation (str, optional): add a final activation function to the classification head
                when `classification` is True. Defaults to "Tanh" for `layers.Activation("tanh")`.
            qkv_bias (bool, optional): apply bias to the qkv linear layer in the self-attention block. Defaults to False.
            save_attn (bool, optional): make the attention in the self-attention block accessible. Defaults to False.
        """
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.mlp_dim = mlp_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.proj_type = proj_type
        self.pos_embed_type = pos_embed_type
        self.classification = classification
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.spatial_dims = spatial_dims
        self.post_activation = post_activation
        self.qkv_bias = qkv_bias
        self.save_attn = save_attn

        # NOTE: the snippet was cut off at "if not (0" (the "<" was swallowed
        # by the forum); a plausible reconstruction of the range check:
        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
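
Since the post breaks off above, here is the kind of round-trip the title refers to, together with the `get_config` override that a subclassed `keras.Model` generally needs so Keras can rebuild it on load. This is a sketch under the assumption that the class above is completed with its layers and a `call` method; it is not the original poster's code:

Code:

# Sketch: a registered subclassed model round-trips through save()/load_model()
# only if Keras can reconstruct it from its config, so every constructor
# argument has to be returned from get_config().
class ViT(keras.Model):
    ...  # __init__ as above, plus layer construction and call()

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "in_channels": self.in_channels,
                "img_size": self.img_size,
                "patch_size": self.patch_size,
                "hidden_size": self.hidden_size,
                "mlp_dim": self.mlp_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "proj_type": self.proj_type,
                "pos_embed_type": self.pos_embed_type,
                "classification": self.classification,
                "num_classes": self.num_classes,
                "dropout_rate": self.dropout_rate,
                "spatial_dims": self.spatial_dims,
                "post_activation": self.post_activation,
                "qkv_bias": self.qkv_bias,
                "save_attn": self.save_attn,
            }
        )
        return config


# The round-trip that the title says fails when classification is enabled
# (argument values here are illustrative):
model = ViT(in_channels=1, img_size=96, patch_size=16, classification=True)
model.save("vit.keras")
restored = keras.models.load_model("vit.keras")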
