Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.
Fine-tuning
Fine-tuning es un proceso de formación continua de modelos previamente entrenados para mejorar el rendimiento en casos de uso específicos.
Fine-tuning Los modelos pequeños que caben completamente en una sola GPU o aquellos que caben completamente en 8 copias del modelo en las CPU son sencillos. No se requiere ningún cambio especial con respecto al entrenamiento de FSDP ordinario. En el caso de modelos de mayor tamaño, hay que considerar la posibilidad de utilizar la función de inicialización diferida de parámetros, que puede resultar complicada.
Para solucionar este problema, la biblioteca de SMP carga el modelo completo en uno de los rangos, mientras que el resto crea modelos con ponderaciones vacías en un metadispositivo. A continuación, el PyTorch FSDP inicializa las ponderaciones de los rangos distintos de cero mediante la init_weights función y sincroniza las ponderaciones de todas las filas con las ponderaciones del rango 0 con el valor establecido en. sync_module_states True En el siguiente fragmento de código se muestra cómo debe configurarlo en su script de entrenamiento.
import torch.distributed as dist from transformers import AutoModelForCasalLM from accelerate import init_empty_weights from torch.sagemaker.delayed_param import DelayedParamIniter if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(..., low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) delayed_initer = DelayedParamIniter(model) model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if dist.get_rank() > 0 else None )
Fine-tuning un modelo de Hugging Face Transformer previamente entrenado con paralelismo tensorial SMP
En esta sección se analiza la carga de modelos de transformador para dos casos de uso: refinamiento de los modelos de transformador pequeños y ajuste de modelos de transformador grandes. Para modelos más pequeños sin demorar la inicialización de los parámetros, empaquete el modelo con la API antes de empaquetarlo con el FSDP. torch.sagemaker.transform PyTorch
import functools from transformers import AutoModelForCausalLM from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.sagemaker import transform model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", low_cpu_mem_usage=True) # Transform model while loading state dictionary from rank 0. tp_model = transform(model, load_state_dict_from_rank0=True) # Wrap with FSDP. model = FSDP( tp_model, ... sync_module_states=True, )
En el caso de los modelos mayores, el método anterior hace que se agote la memoria de la CPU. Le recomendamos que utilice la inicialización diferida de parámetros para evitar estos problemas de memoria de la CPU. En este caso, puede aplicar la API torch.sagemaker.transform y la API torch.sagemaker.delayed_param.DelayedParamIniter como se muestra en el siguiente código de ejemplo.
from transformers import AutoModelForCausalLM from torch.sagemaker import transform from torch.sagemaker.delayed_param import DelayedParamIniter # Create one instance of model without delayed param # on CPU, on one rank. if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(...,low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) # Transform model while loading state dictionary from rank 0 model = transform(model, load_state_dict_from_rank0=True) if dist.get_rank() != 0: # For fine-tuning, delayed parameter on non-zero ranks delayed_initer = DelayedParamIniter(model) else: delayed_initer = None with ( delayed_initer.validate_params_and_buffers_inited() if delayed_initer else nullcontext() ): # Wrap the model with FSDP model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if delayed_initer else None )