clip_benchmark/clip_benchmark/datasets/builder.py (3 additions & 1 deletion)

@@ -519,7 +519,9 @@ def has_kaggle():
     return call('which kaggle', shell=True) == 0


-def build_vtab_dataset(dataset_name, transform, download=True, split='test', data_dir='root', classnames=[]):
+def build_vtab_dataset(dataset_name, transform, download=True, split='test', data_dir='root', classnames=None):
+    if classnames is None:
+        classnames = []
     # Using VTAB splits instead of default TFDS splits
     from .tfds import (VTABIterableDataset, disable_gpus_on_tensorflow,
                        download_tfds_dataset)
internvl_chat/eval/mathvista/utilities.py (3 additions & 1 deletion)

@@ -83,7 +83,9 @@ def contains_number_word(text):
     return False  # If none of the words could be converted to a number, return False


-def contains_quantity_word(text, special_keep_words=[]):
+def contains_quantity_word(text, special_keep_words=None):
+    if special_keep_words is None:
+        special_keep_words = []
     # check if text contains a quantity word
     quantity_words = ['most', 'least', 'fewest',
                       'more', 'less', 'fewer',
internvl_chat/internvl/model/internlm2/modeling_internlm2.py (9 additions & 3 deletions)

@@ -1171,7 +1171,9 @@ def _reorder_cache(past_key_values, beam_idx):
         )
         return reordered_past

-    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=''):
+    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, meta_instruction=''):
+        if history is None:
+            history = []
         if tokenizer.add_bos_token:
             prompt = ''
         else:
@@ -1188,7 +1190,7 @@ def chat(
         self,
         tokenizer,
         query: str,
-        history: List[Tuple[str, str]] = [],
+        history: List[Tuple[str, str]] = None,
         streamer: Optional[BaseStreamer] = None,
         max_new_tokens: int = 1024,
         do_sample: bool = True,
@@ -1199,6 +1201,8 @@ def chat(
         '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.',
         **kwargs,
     ):
+        if history is None:
+            history = []
         inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
         inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
         # also add end-of-assistant token in eos token id to avoid unnecessary generation
@@ -1224,7 +1228,7 @@ def stream_chat(
         self,
         tokenizer,
         query: str,
-        history: List[Tuple[str, str]] = [],
+        history: List[Tuple[str, str]] = None,
         max_new_tokens: int = 1024,
         do_sample: bool = True,
         temperature: float = 0.8,
@@ -1237,6 +1241,8 @@ def stream_chat(
         ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
         ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
         """
+        if history is None:
+            history = []
         if BaseStreamer is None:
             raise ModuleNotFoundError(
                 'The version of `transformers` is too low. Please make sure '
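All three files apply the same fix: a mutable default argument (`[]`) is replaced by a `None` sentinel plus an in-body check. A default value in Python is evaluated once, when the `def` statement runs, so a default list is a single object shared by every call that omits the argument; anything appended to it leaks into later calls. Below is a minimal standalone sketch of the pitfall and the fix; the toy functions `append_buggy` and `append_fixed` are illustrative only and not part of the repositories above.

def append_buggy(item, bucket=[]):
    # `bucket` is created once, at function definition time,
    # and the same list object is reused on every call.
    bucket.append(item)
    return bucket

def append_fixed(item, bucket=None):
    # The `None` sentinel gives each call its own fresh list.
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_buggy('a'))  # ['a']
print(append_buggy('b'))  # ['a', 'b']  state leaked from the first call
print(append_fixed('a'))  # ['a']
print(append_fixed('b'))  # ['b']       independent calls stay independent

This matters most for arguments like `history` in `chat` and `stream_chat`: with the old `history: List[Tuple[str, str]] = []` signature, turns from one conversation could silently accumulate into the shared default and appear in an unrelated later conversation.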