import os, math
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import torch
from torch.nn import functional as F

# Device/precision plan handed to rwkv (see the rwkv package README for the syntax).
strategy = "cpu fp32"
MODEL_NAME = "./RWKV-x060-World-7B-v3-20241112-ctx4096.pth"

# Per the rwkv README, these flags must be set before `rwkv.model` is imported.
os.environ["RWKV_JIT_ON"] = '1'
# '1' makes rwkv use (and compile) its custom CUDA kernel. That requires a CUDA
# toolchain and brings nothing to a CPU-only strategy, so enable it only when the
# strategy actually targets a CUDA device.
os.environ["RWKV_CUDA_ON"] = '1' if 'cuda' in strategy else '0'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

print(f"Loading model - {MODEL_NAME}")
model = RWKV(model=MODEL_NAME, strategy=strategy)
# Tokenizer/decoder bound to the model; "rwkv_vocab_v20230424" names the vocabulary file.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
def _decode_token(token):
    """Return the text of one token id, or U+FFFD if it cannot be decoded on its own.

    A token can be a fragment of a multi-byte UTF-8 character. Decoding it in
    isolation may then raise ``UnicodeDecodeError``. NOTE(review): whether it does
    depends on the tokenizer backend behind ``pipeline`` — confirm.
    """
    try:
        return pipeline.decode([token])
    except UnicodeDecodeError:
        return "\ufffd"


def logprob(s):
    """Score the text ``s`` under the loaded model.

    Token id 0 is prepended as context, so every token of ``s`` gets a conditional
    log-probability. (It is presumably the vocabulary's end-of-text token — confirm.)

    Args:
        s: text to score.

    Returns:
        ``(total, details)``:
        - ``total`` is the summed natural-log probability of all tokens of ``s``.
        - ``details`` lists ``(token_id, token_text, token_logprob)`` tuples in order.
        If ``s`` encodes to no tokens, the result is ``(0.0, [])``.
    """
    tokens = [0] + pipeline.encode(s)
    targets = tokens[1:]
    if not targets:
        # Nothing to score: the log-probability of empty text is 0. This also skips
        # a single-token forward() call. NOTE(review): that call's output is not
        # the (seq, vocab) matrix the indexing below expects — confirm in rwkv.model.
        return 0.0, []
    out, _ = model.forward(tokens, None, full_output=True)
    # Row i holds the logits that predict tokens[i + 1]. The last row predicts past
    # the end of s, so drop it. .float() is defensive and a no-op for fp32 output.
    logits = out[:-1].float()
    positions = torch.arange(len(targets), device=logits.device)
    target_ids = torch.tensor(targets, device=logits.device)
    logprobs = F.log_softmax(logits, dim=-1)[positions, target_ids]
    details = [
        (token, _decode_token(token), token_lp)
        for token, token_lp in zip(targets, logprobs.tolist())
    ]
    return float(torch.sum(logprobs)), details
# Display mapping for token text. Spaces become '_' so leading and trailing
# whitespace stays visible. Line breaks and tabs become escapes so each token
# stays on a single table row.
_TOKEN_DISPLAY = str.maketrans({' ': '_', '\n': '\\n', '\r': '\\r', '\t': '\\t'})


def print_details(d):
    """Print one tab-separated row per token: id, text, log-prob, and probability in %.

    Args:
        d: iterable of ``(token_id, token_text, token_logprob)`` tuples, such as the
            second element returned by ``logprob``.
    """
    for token_id, text, token_lp in d:
        print("%6d\t%20s\t%2.3f\t%3.3f%%" % (
            token_id, text.translate(_TOKEN_DISPLAY), token_lp, math.exp(token_lp) * 100))
# Compare how likely the model finds each "RWKV-<version>" name (followed by a
# newline). For each one, print the total log-probability and its exponent,
# then the per-token breakdown.
for version in "4567":
    total_lp, token_rows = logprob(f"RWKV-{version}\n")
    print(total_lp, math.exp(total_lp))
    print_details(token_rows)
# 1 Like  (forum reaction counter pasted along with the code; not part of the program)