prevent loading tokenizer on import (#1432)

Also adds the sentencepiece dep for the exe build.
Moved the vicuna imports to after an if statement.
In general we should avoid importing modules that load whole models as global variables.
This commit is contained in:
Daniel Garvey 2023-05-12 21:11:45 -05:00 committed by GitHub
parent 4c07e47e8c
commit 4731c1a835
5 changed files with 27 additions and 20 deletions
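
The underlying pattern is straightforward: instead of constructing the model and tokenizer at import time, wrap the construction in a function that callers invoke only when they actually need the objects. A minimal sketch of the before/after refactor, assuming the Hugging Face APIs used in the diff below (the cache is an illustrative addition, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Before: these module-level lines ran on import and pulled the full model into memory.
# vicuna_model = AutoModelForCausalLM.from_pretrained("TheBloke/vicuna-7B-1.1-HF", torch_dtype=torch.float)
# tokenizer = AutoTokenizer.from_pretrained("TheBloke/vicuna-7B-1.1-HF", use_fast=False)

_cache = {}

def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
    # After: nothing is downloaded or loaded until the first call.
    if path not in _cache:
        model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float)
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
        _cache[path] = (model, tokenizer)
    return _cache[path]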

View File

@@ -399,10 +399,11 @@ def compile_vicuna_layer(
return ts_g
path = "TheBloke/vicuna-7B-1.1-HF"
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
kwargs = {"torch_dtype": torch.float}
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
return vicuna_model, tokenizer
def compile_to_vmfb(inputs, layers, is_first=True):
@@ -577,7 +578,7 @@ def get_sharded_model():
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an incredibly hacky process
# please don't change it
SAMPLE_INPUT_LEN = 137
global vicuna_model
vicuna_model = get_model_and_tokenizer()[0]
placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
@@ -611,6 +612,7 @@ if __name__ == "__main__":
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
prologue_prompt = "ASSISTANT:\n"
sharded_model = get_sharded_model()
tokenizer = get_model_and_tokenizer()[1]
past_key_values = None
while True:
print("\n\n")
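
With this change, importing the vicuna script is cheap; the weights are only fetched when one of its entry points runs. Roughly, and with call sites simplified from the hunks above:

from apps.language_models.scripts.sharded_vicuna_fp32 import (
    get_model_and_tokenizer,
    get_sharded_model,
)

sharded_model = get_sharded_model()        # loads the vicuna model internally
tokenizer = get_model_and_tokenizer()[1]   # builds (model, tokenizer) and keeps the tokenizer

Note that indexing the tuple still constructs both objects, so fetching just the tokenizer this way pays for the model as well; splitting the helper would avoid that, but is outside the scope of this commit.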

View File

@@ -24,12 +24,6 @@ from shark.shark_inference import SharkInference
from pathlib import Path
model_path = "stabilityai/stablelm-tuned-alpha-3b"
tok = AutoTokenizer.from_pretrained(model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
print(f"Sucessfully loaded the tokenizer to the memory")
class StopOnTokens(StoppingCriteria):
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
@@ -246,9 +240,12 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
"""
input_ids = torch.randint(3, (1, 256))
attention_mask = torch.randint(3, (1, 256))
sharkModel = 0
def get_tokenizer():
model_path = "stabilityai/stablelm-tuned-alpha-3b"
tok = AutoTokenizer.from_pretrained(model_path)
tok.add_special_tokens({"pad_token": "<PAD>"})
print(f"Sucessfully loaded the tokenizer to the memory")
return tok
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
@@ -263,7 +260,12 @@ def generate(
num_beams,
stopping_criteria,
sharkStableLM,
tok=None,
input_ids=torch.randint(3, (1, 256)),
attention_mask=torch.randint(3, (1, 256)),
):
if tok is None:
tok = get_tokenizer()
# Construct the input message string for the model by concatenating the current system message and conversation history
# Tokenize the messages string
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
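
The generate() change is the usual lazy-default idiom: the tokenizer becomes an optional argument and is constructed on first use when the caller passes nothing. In isolation the idiom looks roughly like this (the module-level cache is an illustrative addition, not part of the diff):

from transformers import AutoTokenizer

_tok = None

def get_tokenizer(model_path="stabilityai/stablelm-tuned-alpha-3b"):
    global _tok
    if _tok is None:
        _tok = AutoTokenizer.from_pretrained(model_path)
        _tok.add_special_tokens({"pad_token": "<PAD>"})
    return _tok

def generate(messages, tok=None):
    # Only hit the network/disk when the caller did not supply a tokenizer.
    if tok is None:
        tok = get_tokenizer()  # note the call; passing the bare function would break tokenization
    return tok(messages, return_tensors="pt")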

View File

@@ -31,6 +31,7 @@ datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('shark')
datas += collect_data_files('tkinter')
datas += collect_data_files('webview')
datas += collect_data_files('sentencepiece')
datas += [
( 'src/utils/resources/prompts.json', 'resources' ),
( 'src/utils/resources/model_db.json', 'resources' ),
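
collect_data_files('sentencepiece') asks PyInstaller to copy that package's non-Python files into the bundle; the slow (use_fast=False) vicuna tokenizer relies on sentencepiece, so without it the frozen exe can fail as soon as the tokenizer is constructed. A rough smoke test one could run inside the packaged app (script and model name are illustrative, not part of the commit):

# check_sentencepiece.py - quick sanity check for the frozen build
import sentencepiece as spm
from transformers import AutoTokenizer

print("sentencepiece", spm.__version__)
tok = AutoTokenizer.from_pretrained("TheBloke/vicuna-7B-1.1-HF", use_fast=False)
print("vocab size:", tok.vocab_size)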

View File

@@ -5,8 +5,7 @@ from apps.language_models.scripts.stablelm import (
compile_stableLM,
StopOnTokens,
generate,
sharkModel,
tok,
get_tokenizer,
StableLMModel,
)
from transformers import (
@@ -15,10 +14,6 @@ from transformers import (
StoppingCriteriaList,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from apps.language_models.scripts.sharded_vicuna_fp32 import (
tokenizer,
get_sharded_model,
)
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
@@ -49,6 +44,11 @@ def chat(curr_system_message, history, model):
global sharded_model
global past_key_values
if "vicuna" in model:
from apps.language_models.scripts.sharded_vicuna_fp32 import (
tokenizer,
get_sharded_model,
)
SAMPLE_INPUT_LEN = 137
curr_system_message = start_message_vicuna
if sharded_model == 0:
@@ -100,6 +100,7 @@ def chat(curr_system_message, history, model):
global sharkModel
print("In chat")
if sharkModel == 0:
tok = get_tokenizer()
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
m = AutoModelForCausalLM.from_pretrained(
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32

View File

@@ -28,6 +28,7 @@ scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile