mirror of https://github.com/nod-ai/SHARK.git
prevent loading tokenizer on import (#1432)
also adds sentencepiece dep for exe moved vicuna imports to after an if statement in general we should avoid importing files that load whole models as global variables
This commit is contained in:
parent
4c07e47e8c
commit
4731c1a835
|
@ -399,10 +399,11 @@ def compile_vicuna_layer(
|
|||
return ts_g
|
||||
|
||||
|
||||
path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
|
||||
def get_model_and_tokenizer(path="TheBloke/vicuna-7B-1.1-HF"):
|
||||
kwargs = {"torch_dtype": torch.float}
|
||||
vicuna_model = AutoModelForCausalLM.from_pretrained(path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
|
||||
return vicuna_model, tokenizer
|
||||
|
||||
|
||||
def compile_to_vmfb(inputs, layers, is_first=True):
|
||||
|
@ -577,7 +578,7 @@ def get_sharded_model():
|
|||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
global vicuna_model
|
||||
vicuna_model = get_model_and_tokenizer()[0]
|
||||
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
|
@ -611,6 +612,7 @@ if __name__ == "__main__":
|
|||
prompt_history = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
sharded_model = get_sharded_model()
|
||||
tokenizer = get_model_and_tokenizer()[1]
|
||||
past_key_values = None
|
||||
while True:
|
||||
print("\n\n")
|
||||
|
|
|
@ -24,12 +24,6 @@ from shark.shark_inference import SharkInference
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
model_path = "stabilityai/stablelm-tuned-alpha-3b"
|
||||
tok = AutoTokenizer.from_pretrained(model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
print(f"Sucessfully loaded the tokenizer to the memory")
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
|
@ -246,9 +240,12 @@ system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
|||
"""
|
||||
|
||||
|
||||
input_ids = torch.randint(3, (1, 256))
|
||||
attention_mask = torch.randint(3, (1, 256))
|
||||
sharkModel = 0
|
||||
def get_tokenizer():
|
||||
model_path = "stabilityai/stablelm-tuned-alpha-3b"
|
||||
tok = AutoTokenizer.from_pretrained(model_path)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
print(f"Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
|
@ -263,7 +260,12 @@ def generate(
|
|||
num_beams,
|
||||
stopping_criteria,
|
||||
sharkStableLM,
|
||||
tok=None,
|
||||
input_ids=torch.randint(3, (1, 256)),
|
||||
attention_mask=torch.randint(3, (1, 256)),
|
||||
):
|
||||
if tok == None:
|
||||
tok = get_tokenizer
|
||||
# Construct the input message string for the model by concatenating the current system message and conversation history
|
||||
# Tokenize the messages string
|
||||
# sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256")
|
||||
|
|
|
@ -31,6 +31,7 @@ datas += collect_data_files('google-cloud-storage')
|
|||
datas += collect_data_files('shark')
|
||||
datas += collect_data_files('tkinter')
|
||||
datas += collect_data_files('webview')
|
||||
datas += collect_data_files('sentencepiece')
|
||||
datas += [
|
||||
( 'src/utils/resources/prompts.json', 'resources' ),
|
||||
( 'src/utils/resources/model_db.json', 'resources' ),
|
||||
|
|
|
@ -5,8 +5,7 @@ from apps.language_models.scripts.stablelm import (
|
|||
compile_stableLM,
|
||||
StopOnTokens,
|
||||
generate,
|
||||
sharkModel,
|
||||
tok,
|
||||
get_tokenizer,
|
||||
StableLMModel,
|
||||
)
|
||||
from transformers import (
|
||||
|
@ -15,10 +14,6 @@ from transformers import (
|
|||
StoppingCriteriaList,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from apps.language_models.scripts.sharded_vicuna_fp32 import (
|
||||
tokenizer,
|
||||
get_sharded_model,
|
||||
)
|
||||
|
||||
start_message = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
|
@ -49,6 +44,11 @@ def chat(curr_system_message, history, model):
|
|||
global sharded_model
|
||||
global past_key_values
|
||||
if "vicuna" in model:
|
||||
from apps.language_models.scripts.sharded_vicuna_fp32 import (
|
||||
tokenizer,
|
||||
get_sharded_model,
|
||||
)
|
||||
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
curr_system_message = start_message_vicuna
|
||||
if sharded_model == 0:
|
||||
|
@ -100,6 +100,7 @@ def chat(curr_system_message, history, model):
|
|||
global sharkModel
|
||||
print("In chat")
|
||||
if sharkModel == 0:
|
||||
tok = get_tokenizer()
|
||||
# sharkModel = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/disk/phaneesh/stablelm_3b_f32_cuda_2048_newflags.vmfb")
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
"stabilityai/stablelm-tuned-alpha-3b", torch_dtype=torch.float32
|
||||
|
|
|
@ -28,6 +28,7 @@ scikit-image
|
|||
pytorch_lightning # for runwayml models
|
||||
tk
|
||||
pywebview
|
||||
sentencepiece
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
|
Loading…
Reference in New Issue