import paddle
import paddle.nn as nn

paddle.set_device('cpu')


class PatchEmbedding(nn.Layer):
    """Splits an image into non-overlapping patches and projects each patch to embed_dim."""
    def __init__(self, embed_dim, patch_size):
        super().__init__()
        # A stride-`patch_size` convolution is equivalent to patchify + per-patch linear projection.
        self.embedding = nn.Conv2D(3, embed_dim, patch_size, patch_size)

    def forward(self, x):
        x = self.embedding(x)         # [B, embed_dim, H/ps, W/ps]
        x = x.flatten(2)              # [B, embed_dim, num_patches]
        x = x.transpose([0, 2, 1])    # [B, num_patches, embed_dim]
        return x
class PatchMerging(nn.Layer):
    """Downsamples the token grid by 2x and doubles the channel dimension."""
    def __init__(self, dim, input_resolution):
        super().__init__()
        self.resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        h, w = self.resolution
        B, _, C = x.shape
        x = x.reshape([B, h, w, C])
        # Gather the four tokens of each 2x2 neighbourhood.
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 0::2, 1::2, :]
        x2 = x[:, 1::2, 0::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = paddle.concat([x0, x1, x2, x3], axis=-1)    # [B, h/2, w/2, 4C]
        x = x.reshape([B, -1, 4 * C])
        x = self.norm(x)
        x = self.reduction(x)                           # [B, h*w/4, 2C]
        return x
class Mlp(nn.Layer):
    """Two-layer feed-forward network used after attention in each block."""
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
        self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
def window_partition(x, ws):
    """Splits a [B, H, W, C] feature map into non-overlapping ws x ws windows."""
    B, H, W, C = x.shape
    x = x.reshape([B, H // ws, ws, W // ws, ws, C])
    x = x.transpose([0, 1, 3, 2, 4, 5])
    x = x.reshape([-1, ws, ws, C])      # [B * num_windows, ws, ws, C]
    return x


def window_reverse(windows, ws, H, W):
    """Inverse of window_partition: merges windows back into a [B, H, W, C] feature map."""
    B = int(windows.shape[0] // (H // ws * W // ws))
    x = windows.reshape([B, H // ws, W // ws, ws, ws, -1])
    x = x.transpose([0, 1, 3, 2, 4, 5])
    x = x.reshape([B, H, W, -1])
    return x
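
# Illustrative sanity check (an assumption of this sketch: H and W are divisible by the
# window size). window_reverse should exactly invert window_partition:
#   _x = paddle.randn([2, 56, 56, 96])
#   _w = window_partition(_x, 7)                          # [2*8*8, 7, 7, 96]
#   assert paddle.allclose(window_reverse(_w, 7, 56, 56), _x)
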
class WindowsAttention(nn.Layer):
    """Multi-head self-attention computed independently inside each window."""
    def __init__(self, dim, ws, num_heads):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_head = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.scale = self.dim_head ** -0.5
        self.softmax = nn.Softmax(-1)
        self.proj = nn.Linear(dim, dim)

    def transpose_multi_head(self, x):
        # [B*, N, dim] -> [B*, num_heads, N, dim_head]
        new_shape = x.shape[:-1] + [self.num_heads, self.dim_head]
        x = x.reshape(new_shape)
        x = x.transpose([0, 2, 1, 3])
        return x

    def forward(self, x):
        qkv = self.qkv(x).chunk(3, -1)
        q, k, v = map(self.transpose_multi_head, qkv)
        attn = paddle.matmul(q, k, transpose_y=True)    # [B*, heads, N, N]
        attn = attn * self.scale
        attn = self.softmax(attn)
        out = paddle.matmul(attn, v)                    # [B*, heads, N, dim_head]
        out = out.transpose([0, 2, 1, 3])
        out = out.flatten(2)                            # [B*, N, dim]
        out = self.proj(out)
        return out
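
# Note: this is a simplified window attention. The full Swin Transformer design also adds a
# learnable relative position bias to the attention logits and, in shifted-window blocks,
# an attention mask; both are omitted here, which is why `ws` is not used above.
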
class SwinBlock(nn.Layer):
    """One (non-shifted) Swin block: window attention + MLP, each with a pre-norm residual."""
    def __init__(self, dim, ws, num_heads, input_resolution):
        super().__init__()
        self.dim = dim
        self.ws = ws
        self.num_heads = num_heads
        self.resolution = input_resolution
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = WindowsAttention(dim, ws, num_heads)
        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = Mlp(dim)

    def forward(self, x):
        H, W = self.resolution
        B, N, C = x.shape
        # Window attention with a residual connection.
        h = x
        x = self.attn_norm(x)
        x = x.reshape([B, H, W, C])
        x_windows = window_partition(x, self.ws)                # [B*num_windows, ws, ws, C]
        x_windows = x_windows.reshape([-1, self.ws * self.ws, C])
        x = self.attn(x_windows)
        x = x.reshape([-1, self.ws, self.ws, C])
        x = window_reverse(x, self.ws, H, W)                    # back to [B, H, W, C]
        x = x.reshape([B, H * W, C])
        x = h + x
        # MLP with a residual connection.
        h = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x
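
# Note: a full Swin stage alternates this block (W-MSA) with a shifted-window block
# (SW-MSA, built by applying paddle.roll before and after the window partition); only the
# non-shifted variant is implemented in this sketch.
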
def main():
    # Batch of 4 RGB images at 224x224.
    t = paddle.randn([4, 3, 224, 224])
    patch_embedding = PatchEmbedding(patch_size=4, embed_dim=96)
    swin_block = SwinBlock(dim=96, ws=7, num_heads=4, input_resolution=[56, 56])
    patch_merging = PatchMerging(input_resolution=[56, 56], dim=96)
    out = patch_embedding(t)
    print('patch_embed out shape:', out.shape)
    out1 = swin_block(out)
    print('swin_block out shape:', out1.shape)
    out2 = patch_merging(out1)
    print('patch_merging out shape:', out2.shape)
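
# With the 224x224 input above there are (224/4)^2 = 3136 patch tokens, so the
# expected printed shapes are:
#   patch_embed out shape: [4, 3136, 96]
#   swin_block out shape: [4, 3136, 96]
#   patch_merging out shape: [4, 784, 192]
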
if __name__ == '__main__':
    main()