
Training a causal language model from scratch (TensorFlow)

Install the Transformers, Datasets, and Evaluate libraries to run this notebook.

!pip install datasets evaluate transformers[sentencepiece]
!apt install git-lfs

You will need to set up Git; adapt your email and name in the following cell.

!git config --global user.email "you@example.com"
!git config --global user.name "Your Name"

You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials.

from huggingface_hub import notebook_login

notebook_login()
def any_keyword_in_string(string, keywords):
    for keyword in keywords:
        if keyword in string:
            return True
    return False
filters = ["pandas", "sklearn", "matplotlib", "seaborn"]
example_1 = "import numpy as np"
example_2 = "import pandas as pd"

print(
    any_keyword_in_string(example_1, filters), any_keyword_in_string(example_2, filters)
)
False True
from collections import defaultdict
from tqdm import tqdm
from datasets import Dataset


def filter_streaming_dataset(dataset, filters):
    filtered_dict = defaultdict(list)
    total = 0
    for sample in tqdm(iter(dataset)):
        total += 1
        if any_keyword_in_string(sample["content"], filters):
            for k, v in sample.items():
                filtered_dict[k].append(v)
    print(f"{len(filtered_dict['content'])/total:.2%} of data after filtering.")
    return Dataset.from_dict(filtered_dict)
# This cell will take a very long time to execute, so you should skip it and go to
# the next one!
from datasets import load_dataset

split = "train"  # "valid"
filters = ["pandas", "sklearn", "matplotlib", "seaborn"]

data = load_dataset(f"transformersbook/codeparrot-{split}", split=split, streaming=True)
filtered_data = filter_streaming_dataset(data, filters)
3.26% of data after filtering.
from datasets import load_dataset, DatasetDict

ds_train = load_dataset("huggingface-course/codeparrot-ds-train", split="train")
ds_valid = load_dataset("huggingface-course/codeparrot-ds-valid", split="validation")

raw_datasets = DatasetDict(
    {
        "train": ds_train,  # .shuffle().select(range(50000)),
        "valid": ds_valid,  # .shuffle().select(range(500))
    }
)

raw_datasets
DatasetDict({
    train: Dataset({
        features: ['repo_name', 'path', 'copies', 'size', 'content', 'license'],
        num_rows: 606720
    })
    valid: Dataset({
        features: ['repo_name', 'path', 'copies', 'size', 'content', 'license'],
        num_rows: 3322
    })
})
for key in raw_datasets["train"][0]:
    print(f"{key.upper()}: {raw_datasets['train'][0][key][:200]}")
'REPO_NAME: kmike/scikit-learn'
'PATH: sklearn/utils/__init__.py'
'COPIES: 3'
'SIZE: 10094'
'''CONTENT: """
The :mod:`sklearn.utils` module includes various utilites.
"""

from collections import Sequence

import numpy as np
from scipy.sparse import issparse
import warnings

from .murmurhash import murm
LICENSE: bsd-3-clause'''
from transformers import AutoTokenizer

context_length = 128
tokenizer = AutoTokenizer.from_pretrained("huggingface-course/code-search-net-tokenizer")

outputs = tokenizer(
    raw_datasets["train"][:2]["content"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")
Input IDs length: 34
Input chunk lengths: [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 117, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 41]
Chunk mapping: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
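The two shorter chunks (117 and 41 tokens) are the leftover tails of the two documents after they have been split into 128-token pieces. A quick way to see this for yourself (a hypothetical check, not part of the original notebook) is to decode one of them:

# Hypothetical sanity check (not in the original notebook): the last chunk
# (41 tokens) is just the tail of the second document after 128-token splitting.
print(tokenizer.decode(outputs["input_ids"][-1]))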
def tokenize(element):
    outputs = tokenizer(
        element["content"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}


tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)
tokenized_datasets
DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 16702061
    })
    valid: Dataset({
        features: ['input_ids'],
        num_rows: 93164
    })
})
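With roughly 16.7 million training chunks of 128 tokens each, this corresponds to about 2.1 billion tokens in total.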
from transformers import AutoTokenizer, TFGPT2LMHeadModel, AutoConfig

config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = TFGPT2LMHeadModel(config)
model(model.dummy_inputs)  # Builds the model
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
transformer (TFGPT2MainLayer multiple                  124242432
=================================================================
Total params: 124,242,432
Trainable params: 124,242,432
Non-trainable params: 0
_________________________________________________________________
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="tf")
out = data_collator([tokenized_datasets["train"][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")
input_ids shape: (5, 128)
attention_mask shape: (5, 128)
labels shape: (5, 128)
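With mlm=False the collator simply copies the input_ids into the labels (padding positions become -100 so the loss ignores them); the one-token shift needed for next-token prediction happens inside the model. A small sketch to confirm this, not part of the original notebook:

import tensorflow as tf

# Extra sanity check (an assumption, not in the original notebook): every label
# is either a copy of the corresponding input token or -100 (ignored by the loss).
same_or_ignored = (out["labels"] == out["input_ids"]) | (out["labels"] == -100)
print(tf.reduce_all(same_or_ignored).numpy())  # expected: True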
tf_train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    collate_fn=data_collator,
    shuffle=True,
    batch_size=32,
)
tf_eval_dataset = model.prepare_tf_dataset(
    tokenized_datasets["valid"],
    collate_fn=data_collator,
    shuffle=False,
    batch_size=32,
)
from huggingface_hub import notebook_login

notebook_login()
from transformers import create_optimizer
import tensorflow as tf

num_train_steps = len(tf_train_dataset)
optimizer, schedule = create_optimizer(
    init_lr=5e-5,
    num_warmup_steps=1_000,
    num_train_steps=num_train_steps,
    weight_decay_rate=0.01,
)
model.compile(optimizer=optimizer)

# Train in mixed-precision float16
tf.keras.mixed_precision.set_global_policy("mixed_float16")
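Note that mixed-precision float16 only pays off on GPUs with hardware support for it (for example, recent NVIDIA GPUs); if your hardware does not support it, simply skip the last line.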
from transformers.keras_callbacks import PushToHubCallback

callback = PushToHubCallback(output_dir="codeparrot-ds", tokenizer=tokenizer)

model.fit(tf_train_dataset, validation_data=tf_eval_dataset, callbacks=[callback])
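If you also want local training logs, a standard Keras TensorBoard callback can be passed alongside the Hub callback. This is an optional sketch rather than part of the original cell, and the log directory name is an assumption:

import tensorflow as tf

# Optional sketch (assumption, not in the original notebook): log metrics locally
# with TensorBoard in addition to pushing checkpoints to the Hub.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="codeparrot-ds/logs")

model.fit(
    tf_train_dataset,
    validation_data=tf_eval_dataset,
    callbacks=[callback, tensorboard_callback],
)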
from transformers import pipeline

course_model = TFGPT2LMHeadModel.from_pretrained("huggingface-course/codeparrot-ds")
course_tokenizer = AutoTokenizer.from_pretrained("huggingface-course/codeparrot-ds")
pipe = pipeline(
    "text-generation", model=course_model, tokenizer=course_tokenizer, device=0
)
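The device=0 argument places the pipeline on the first GPU; pass device=-1 (or omit the argument) to run generation on the CPU instead.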
txt = """\ # create some data x = np.random.randn(100) y = np.random.randn(100) # create scatter plot with x, y """ print(pipe(txt, num_return_sequences=1)[0]["generated_text"])
# create some data
x = np.random.randn(100)
y = np.random.randn(100)

# create scatter plot with x, y
plt.scatter(x, y)

# create scatter
txt = """\ # create some data x = np.random.randn(100) y = np.random.randn(100) # create dataframe from x and y """ print(pipe(txt, num_return_sequences=1)[0]["generated_text"])
# create some data
x = np.random.randn(100)
y = np.random.randn(100)

# create dataframe from x and y
df = pd.DataFrame({'x': x, 'y': y})
df.insert(0,'x', x)
for
txt = """\ # dataframe with profession, income and name df = pd.DataFrame({'profession': x, 'income':y, 'name': z}) # calculate the mean income per profession """ print(pipe(txt, num_return_sequences=1)[0]["generated_text"])
# dataframe with profession, income and name
df = pd.DataFrame({'profession': x, 'income':y, 'name': z})

# calculate the mean income per profession
profession = df.groupby(['profession']).mean()

# compute the
txt = """ # import random forest regressor from scikit-learn from sklearn.ensemble import RandomForestRegressor # fit random forest model with 300 estimators on X, y: """ print(pipe(txt, num_return_sequences=1)[0]["generated_text"])
# import random forest regressor from scikit-learn
from sklearn.ensemble import RandomForestRegressor

# fit random forest model with 300 estimators on X, y:
rf = RandomForestRegressor(n_estimators=300, random_state=random_state, max_depth=3)
rf.fit(X, y)
rf
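To put a number on how well the trained model predicts code, you can evaluate the cross-entropy loss on the validation set and report its exponential as the perplexity. This is a minimal sketch, assuming the compiled model and the tf_eval_dataset pipeline from above are still in scope:

import math

# Minimal sketch (not in the original notebook): the compiled model uses its
# internal loss, so evaluate() returns the mean cross-entropy on the split.
eval_loss = model.evaluate(tf_eval_dataset)
print(f"Validation perplexity: {math.exp(eval_loss):.2f}")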