UNet(
  (t_proj): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): SiLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
  )
  (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downs): ModuleList(
    (0): DownBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 32, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=64, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 64, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (residual_input_conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      (down_sample_conv): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (1): DownBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 64, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=128, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 128, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 128, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (residual_input_conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (down_sample_conv): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (2): DownBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 128, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=256, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 256, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 256, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (residual_input_conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
      (down_sample_conv): Identity()
    )
  )
  (mids): ModuleList(
    (0): MidBlock(
      (resnet_conv_first): ModuleList(
        (0-1): 2 x Sequential(
          (0): GroupNorm(8, 256, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (t_emb_layers): ModuleList(
        (0-1): 2 x Sequential(
          (0): SiLU()
          (1): Linear(in_features=128, out_features=256, bias=True)
        )
      )
      (resnet_conv_second): ModuleList(
        (0-1): 2 x Sequential(
          (0): GroupNorm(8, 256, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (attention_norm): GroupNorm(8, 256, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (residual_input_conv): ModuleList(
        (0-1): 2 x Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (1): MidBlock(
      (resnet_conv_first): ModuleList(
        (0): Sequential(
          (0): GroupNorm(8, 256, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): Sequential(
          (0): GroupNorm(8, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (t_emb_layers): ModuleList(
        (0-1): 2 x Sequential(
          (0): SiLU()
          (1): Linear(in_features=128, out_features=128, bias=True)
        )
      )
      (resnet_conv_second): ModuleList(
        (0-1): 2 x Sequential(
          (0): GroupNorm(8, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (attention_norm): GroupNorm(8, 128, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (residual_input_conv): ModuleList(
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      )
    )
  )
  (ups): ModuleList(
    (0): UpBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 256, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=64, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 64, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (residual_input_conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (up_sample_conv): Identity()
    )
    (1): UpBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 128, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=32, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 32, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 32, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
      )
      (residual_input_conv): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
      (up_sample_conv): ConvTranspose2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
    (2): UpBlock(
      (resnet_conv_first): Sequential(
        (0): GroupNorm(8, 64, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (t_emb_layers): Sequential(
        (0): SiLU()
        (1): Linear(in_features=128, out_features=16, bias=True)
      )
      (resnet_conv_second): Sequential(
        (0): GroupNorm(8, 16, eps=1e-05, affine=True)
        (1): SiLU()
        (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (attention_norm): GroupNorm(8, 16, eps=1e-05, affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
      )
      (residual_input_conv): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (up_sample_conv): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    )
  )
  (norm_out): GroupNorm(8, 16, eps=1e-05, affine=True)
  (conv_out): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
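
This is the torch repr of a DDPM-style UNet: three DownBlocks widen the features 32 -> 64 -> 128 -> 256 channels (the first two halve the resolution with a 4x4 stride-2 conv, the third keeps it via Identity), two MidBlocks sit at the bottleneck, and three UpBlocks mirror the path back down to 16 channels before norm_out / conv_out project to the single output channel. Every block repeats one pattern: GroupNorm -> SiLU -> 3x3 Conv, the projected time embedding added between the two convs, self-attention over the flattened spatial grid, and a 1x1 conv on the residual path. Below is a minimal sketch of one DownBlock, reconstructed from the printed shapes; the attention head count is an assumption, since the repr of MultiheadAttention only shows its out_proj.

import torch
import torch.nn as nn

class DownBlock(nn.Module):
    """Sketch of one encoder block from the printout: a ResNet pair with the
    time embedding injected between the convs, spatial self-attention, and an
    optional 4x4 stride-2 downsampling conv. Reconstructed from the printed
    shapes; internals such as num_heads are assumptions."""

    def __init__(self, in_ch, out_ch, t_emb_dim=128, num_heads=4, down_sample=True):
        super().__init__()
        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_ch), nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        )
        self.t_emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_ch))
        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_ch), nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )
        self.attention_norm = nn.GroupNorm(8, out_ch)
        self.attention = nn.MultiheadAttention(out_ch, num_heads, batch_first=True)
        self.residual_input_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.down_sample_conv = (
            nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)
            if down_sample else nn.Identity()
        )

    def forward(self, x, t_emb):
        # ResNet part: the time embedding is broadcast-added between the convs.
        h = self.resnet_conv_first(x)
        h = h + self.t_emb_layers(t_emb)[:, :, None, None]
        h = self.resnet_conv_second(h)
        h = h + self.residual_input_conv(x)
        # Self-attention over flattened spatial positions: (B, H*W, C).
        b, c, hh, ww = h.shape
        a = self.attention_norm(h).reshape(b, c, hh * ww).transpose(1, 2)
        a, _ = self.attention(a, a, a)
        h = h + a.transpose(1, 2).reshape(b, c, hh, ww)
        return self.down_sample_conv(h)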
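
The top-level t_proj (Linear -> SiLU -> Linear, all 128-dim) expects a 128-dim timestep embedding as its input. The embedding function itself carries no parameters, so it is invisible in the repr; a sinusoidal embedding as in the DDPM paper is the usual choice and is assumed here.

import math
import torch

def sinusoidal_time_embedding(timesteps, dim=128):
    """Sinusoidal timestep embedding (an assumption; not visible in the repr).

    timesteps: (B,) integer tensor of diffusion steps.
    returns:   (B, dim) float tensor, suitable as input to UNet.t_proj.
    """
    half = dim // 2
    freqs = torch.exp(
        -torch.arange(half, dtype=torch.float32) * (math.log(10000.0) / half)
    )
    args = timesteps[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)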
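
A quick smoke test wiring the two sketches together. The single-channel conv_in and the two 2x downsamples suggest something like 28x28 MNIST input (an assumption), so spatial sizes would run 28 -> 14 -> 7 and back up.

block = DownBlock(in_ch=32, out_ch=64)    # mirrors downs[0] in the repr
x = torch.randn(4, 32, 28, 28)            # stand-in for the conv_in output
t_emb = sinusoidal_time_embedding(torch.randint(0, 1000, (4,)))
# In the full model, t_emb would first pass through t_proj (128 -> 128).
out = block(x, t_emb)
print(out.shape)                          # torch.Size([4, 64, 14, 14])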