
Latest Google AI algorithm news (Google AI open-sources BiT: exploring large-scale pre-training for computer vision)


Big Transfer (BiT): General Visual Representation Learning

Introduction

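In this repository we release multiple models from the Big Transfer (BiT): General Visual Representation Learning paper that were pre-trained on the ILSVRC-2012 and ImageNet-21k datasets. We provide the code to fine-tune the released models in the major deep learning frameworks TensorFlow 2, PyTorch and Jax/Flax.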

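We hope that the computer vision community will benefit from employing the more powerful ImageNet-21k pre-trained models, as opposed to conventional models pre-trained on the ILSVRC-2012 dataset.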

We also provide Colab notebooks for more exploratory, interactive use: a TensorFlow 2 Colab, a PyTorch Colab and a Jax Colab.

Installation


Make sure you have Python>=3.6 installed on your machine.

To set up TensorFlow 2, PyTorch or Jax, follow the installation instructions provided in the corresponding official repository.

In addition, install the Python dependencies by running the following (select tf2, pytorch or jax in the command below):

pip install -r bit_{tf2|pytorch|jax}/requirements.txt

How to fine-tune BiT


First, download the BiT model. We provide models pre-trained on ILSVRC-2012 (BiT-S) or ImageNet-21k (BiT-M) for 5 different architectures: ResNet-50x1, ResNet-101x1, ResNet-50x3, ResNet-101x3 and ResNet-152x4.

For example, if you would like to download the ResNet-50x1 pre-trained on ImageNet-21k, run the following command:

wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}

Other models can be downloaded accordingly by plugging the name of the model (BiT-S or BiT-M) and the architecture into the above command. Note that we provide models in two formats: npz (for PyTorch and Jax) and h5 (for TF2). By default, we expect the model weights to be stored in the root folder of this repository.
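The naming convention above can be captured in a small helper. This is a hypothetical sketch, not part of the BiT repository: the function name bit_model_url and the framework-to-format mapping are illustrative, and it assumes that every family/architecture combination listed in this README is published under the same URL pattern as the wget example.

```python
# Hypothetical helper (not part of the BiT repository): builds a checkpoint URL
# following the naming pattern of the wget example above.
BASE_URL = "https://storage.googleapis.com/bit_models"

# BiT-S: pre-trained on ILSVRC-2012; BiT-M: pre-trained on ImageNet-21k.
FAMILIES = {"BiT-S", "BiT-M"}
# Architectures named in this README, including the later R152x2 addition.
# Assumption: each one is available for both families.
ARCHITECTURES = {"R50x1", "R101x1", "R50x3", "R101x3", "R152x2", "R152x4"}
# npz checkpoints for PyTorch and Jax, h5 checkpoints for TF2.
FORMAT_BY_FRAMEWORK = {"pytorch": "npz", "jax": "npz", "tf2": "h5"}


def bit_model_url(family: str, arch: str, framework: str) -> str:
    """Return the download URL of a BiT checkpoint in the framework's format."""
    if family not in FAMILIES:
        raise ValueError(f"unknown model family: {family!r}")
    if arch not in ARCHITECTURES:
        raise ValueError(f"unknown architecture: {arch!r}")
    ext = FORMAT_BY_FRAMEWORK[framework]  # KeyError for unsupported frameworks
    return f"{BASE_URL}/{family}-{arch}.{ext}"


if __name__ == "__main__":
    print(bit_model_url("BiT-M", "R50x1", "pytorch"))
```

The resulting URL can be passed to wget as in the command above; the file should then be placed in the repository root, where the training scripts expect it by default.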

Then you can fine-tune the downloaded model on your dataset of interest in any of the three frameworks. All frameworks share the same command-line interface:

python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
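The same invocation can be assembled from Python, which is convenient when sweeping over models or datasets. The helper build_train_command below is hypothetical; it only uses flags mentioned in this README (--name, --model, --logdir, --dataset, plus optional extras such as --examples_per_class or --batch_split) and reproduces the `date +%F_%H%M%S` run-name suffix with datetime.strftime.

```python
import shlex
from datetime import datetime


def build_train_command(framework, model, dataset, logdir="/tmp/bit_logs",
                        extra_flags=None, now=None):
    """Return the argv list for `python3 -m bit_<framework>.train` (hypothetical helper).

    The run name mimics `date +%F_%H%M%S`, i.e. <dataset>_YYYY-MM-DD_HHMMSS.
    """
    if framework not in ("pytorch", "jax", "tf2"):
        raise ValueError(f"unsupported framework: {framework!r}")
    if now is None:
        now = datetime.now()
    # %F is not portable in Python's strftime, so spell it out as %Y-%m-%d.
    name = f"{dataset}_{now.strftime('%Y-%m-%d_%H%M%S')}"
    argv = ["python3", "-m", f"bit_{framework}.train",
            "--name", name, "--model", model,
            "--logdir", logdir, "--dataset", dataset]
    # Optional flags, e.g. {"batch_split": 4} -> ["--batch_split", "4"].
    for flag, value in (extra_flags or {}).items():
        argv += [f"--{flag}", str(value)]
    return argv


if __name__ == "__main__":
    cmd = build_train_command("pytorch", "BiT-M-R50x1", "cifar10",
                              now=datetime(2020, 5, 21, 9, 30, 0))
    print(shlex.join(cmd))
```

Passing the list to subprocess.run(cmd, check=True) launches the run without going through a shell, so no quoting of the timestamp is needed.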

Currently, all frameworks will automatically download the CIFAR-10 and CIFAR-100 datasets. Other public or custom datasets can be easily integrated: in TF2 and Jax we rely on the extensible TensorFlow Datasets library, and in PyTorch we use torchvision's data input pipeline.

Note that our code uses all available GPUs for fine-tuning.

We also support training in the low-data regime: the --examples_per_class <K> option will randomly draw K samples per class for training.
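The idea behind --examples_per_class can be illustrated with a minimal, dependency-free sketch. This is not the repository's data pipeline: the function sample_k_per_class is hypothetical and simply draws K random example indices per class from a list of labels, with a fixed seed so the subset is reproducible.

```python
import random
from collections import defaultdict


def sample_k_per_class(labels, k, seed=0):
    """Return sorted indices of k randomly drawn examples per class.

    Hypothetical sketch of a low-data subset; raises if a class has fewer
    than k examples.
    """
    # Group example indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)  # fixed seed -> reproducible subset
    chosen = []
    for label, indices in by_class.items():
        if len(indices) < k:
            raise ValueError(f"class {label!r} has only {len(indices)} examples")
        chosen.extend(rng.sample(indices, k))  # sampling without replacement
    return sorted(chosen)


if __name__ == "__main__":
    labels = [i % 10 for i in range(1000)]  # toy dataset: 10 classes, 100 each
    subset = sample_k_per_class(labels, k=5)
    print(len(subset))  # 50 indices, 5 per class
```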

To see a detailed list of all available flags run python3 -m bit_{pytorch|jax|tf2}.train --help.

Available architectures

We release all architectures mentioned in the paper, so that you can choose between accuracy and speed: R50x1, R101x1, R50x3, R101x3 and R152x4. In the above path to the model file, simply replace R50x1 with your architecture of choice.

We further investigated more architectures after the paper's publication and found R152x2 to offer a good trade-off between speed and accuracy; hence we also include it in the release and provide a few numbers below.

Hyper-parameters

For reproducibility, our training script uses the hyper-parameters (BiT-HyperRule) that were used in the original paper. Note, however, that BiT models were trained and fine-tuned on Cloud TPU hardware, so on a typical GPU setup our default hyper-parameters may require too much memory or result in very slow progress. Moreover, BiT-HyperRule is designed to generalize across many datasets, so it is typically possible to devise more efficient application-specific hyper-parameters. We therefore encourage users to try more lightweight settings, as they require far fewer resources and often result in similar accuracy.

For example, we tested our code on an 8xV100 GPU machine on the CIFAR-10 and CIFAR-100 datasets while reducing the batch size from 512 to 128 and the learning rate from 0.003 to 0.001. This setup resulted in nearly identical performance (see Expected results below) compared to BiT-HyperRule, despite being less computationally demanding.
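The learning-rate change above can be pictured with a minimal warmup-plus-step-decay schedule. This is a sketch, not the repository's bit_hyperrule.py: the function step_decay_lr, the warmup length, the decay boundaries and the 10x decay factor are assumptions chosen for illustration; only the base rates 0.003 and 0.001 come from the text.

```python
def step_decay_lr(step, base_lr, boundaries=(500, 3000, 6000, 9000, 10_000)):
    """Learning rate at `step`, or None once training should stop.

    Assumed, illustrative shape: linear warmup until boundaries[0], then the
    rate is divided by 10 at each intermediate boundary; training ends at
    boundaries[-1]. These boundaries are placeholders, not BiT's real values.
    """
    warmup, *decays = boundaries
    end = boundaries[-1]
    if step < warmup:
        return base_lr * step / warmup  # linear warmup from 0 to base_lr
    if step >= end:
        return None  # signals the end of training
    lr = base_lr
    for boundary in decays[:-1]:  # the last boundary only marks the end
        if step >= boundary:
            lr /= 10
    return lr


if __name__ == "__main__":
    # Compare the paper's base rate with the lighter GPU setting.
    for base in (0.003, 0.001):
        print(base, [step_decay_lr(s, base) for s in (250, 1000, 4000, 7000, 9500, 10_000)])
```

Because the whole schedule scales with base_lr, lowering the base rate together with the batch size (as in the experiment above) shifts every phase of the schedule down by the same factor.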

Below we provide more suggestions on how to optimize our paper's setup.

Tips for optimizing memory or speed

The default BiT-HyperRule was developed on Cloud TPUs and is quite memory-hungry. This is mainly due to the large batch size (512) and image resolution (up to 480x480). Here are some tips if you are running out of memory:

  1. In bit_hyperrule.py we specify the input resolution. Reducing it saves a lot of memory and compute, at the expense of accuracy.
  2. The batch size can be reduced in order to lower memory consumption. However, one then also needs to tune the learning rate and schedule (steps) in order to maintain the desired accuracy.
  3. The PyTorch codebase supports a batch-splitting technique ("micro-batching") via the --batch_split option. For example, running the fine-tuning with --batch_split 8 reduces the memory requirement by a factor of 8.
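The micro-batching tip trades time for memory. Below is a dependency-free sketch of the underlying idea, gradient accumulation, using a toy one-parameter least-squares model instead of a network; it is not the repository's PyTorch code, and the function names are hypothetical. Splitting a batch into N equal micro-batches, scaling each micro-batch's mean gradient by 1/N and summing gives the same gradient as processing the full batch at once, while only one micro-batch needs to be held in memory at a time.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, batch_split):
    """Same gradient, computed over `batch_split` equal micro-batches."""
    if len(xs) % batch_split:
        raise ValueError("batch size must be divisible by batch_split")
    size = len(xs) // batch_split
    grad = 0.0
    for start in range(0, len(xs), size):
        # Each micro-batch contributes its mean gradient, scaled by 1/N.
        grad += mse_grad(w, xs[start:start + size], ys[start:start + size]) / batch_split
    return grad


if __name__ == "__main__":
    xs = [float(i) for i in range(512)]  # a "batch" of 512 examples
    ys = [3.0 * x + 1.0 for x in xs]
    full = mse_grad(0.5, xs, ys)
    split = accumulated_grad(0.5, xs, ys, batch_split=8)
    print(full, split)  # equal up to floating-point rounding
```

The equivalence is exact only when no layer depends on batch statistics; BiT models use Group Normalization with Weight Standardization rather than BatchNorm, so splitting the batch should not change the computed update.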

Expected results

We verified that when using the BiT-HyperRule the code in this repository reproduces the paper's results.


[Result tables: images not reproduced]

ImageNet results

These results were obtained using BiT-HyperRule. However, because it implies a large batch size and a large resolution, memory can be an issue. The PyTorch code supports batch splitting, so we can still run these experiments without resorting to Cloud TPUs by adding the --batch_split N option, where N is a power of two. For instance, the following command produces a validation accuracy of 80.68 on a machine with 8 V100 GPUs:

python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4

Increase this further to --batch_split 8 when running with 4 V100 GPUs, and so on.
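The pairing of --batch_split with the GPU count follows simple arithmetic: each forward/backward pass processes batch / N images, and, assuming that micro-batch is spread evenly over the GPUs as in data parallelism, each GPU sees batch / (N × GPUs) images. The helper per_gpu_microbatch below is hypothetical; the batch size of 512 is the BiT-HyperRule value stated above.

```python
def per_gpu_microbatch(batch_size, batch_split, num_gpus):
    """Images each GPU processes per forward/backward pass (hypothetical helper).

    Assumes the micro-batch is divided evenly across GPUs (data parallelism).
    """
    # The README asks for N to be a power of two.
    if batch_split <= 0 or batch_split & (batch_split - 1):
        raise ValueError("batch_split must be a power of two")
    if batch_size % batch_split:
        raise ValueError("batch size must be divisible by batch_split")
    micro = batch_size // batch_split
    if micro % num_gpus:
        raise ValueError("micro-batch must split evenly across GPUs")
    return micro // num_gpus


if __name__ == "__main__":
    print(per_gpu_microbatch(512, 4, 8))  # 8 V100s, --batch_split 4 → 16
    print(per_gpu_microbatch(512, 8, 4))  # 4 V100s, --batch_split 8 → 16
```

Halving the number of GPUs while doubling N keeps the per-GPU load, and therefore the peak activation memory per GPU, roughly constant, which is why the two settings in the text are interchangeable apart from running time.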

Full results achieved that way in some test runs were:

[Full results table: image not reproduced]

Code: https://github.com/google-research/big_transfer.git

Paper: https://arxiv.org/abs/1912.11370
