grammar support for LLaMA
guinmoon committed Sep 26, 2023
1 parent af35ef1 commit 87d188d
Showing 12 changed files with 234 additions and 45 deletions.
4 changes: 4 additions & 0 deletions LLMFarm/AIChatModel.swift
@@ -102,6 +102,10 @@ final class AIChatModel: ObservableObject {
}
do{
if chat_config!["model_inference"] as! String == "llama"{
if (chat_config!["grammar"] != nil && chat_config!["grammar"] as! String != "<None>" && chat_config!["grammar"] as! String != ""){
let grammar_path = get_grammar_path_by_name(chat_config!["grammar"] as! String)
model_context_param.grammar_path = grammar_path
}
if modelURL.hasSuffix(".gguf"){
try model_load_res = self.chat?.loadModel(ModelInference.LLama_gguf,contextParams: model_context_param)
}else{
47 changes: 47 additions & 0 deletions LLMFarm/Lib/FileHelper.swift
@@ -239,6 +239,53 @@ public func get_models_list() -> [Dictionary<String, String>]?{
return res
}

public func get_grammar_path_by_name(_ grammar_name:String) -> String?{
do {
let fileManager = FileManager.default
let documentsPath = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first
let destinationURL = documentsPath!.appendingPathComponent("grammars")
try fileManager.createDirectory (at: destinationURL, withIntermediateDirectories: true, attributes: nil)
let path = destinationURL.appendingPathComponent(grammar_name).path
if fileManager.fileExists(atPath: path){
return path
}else{
return nil
}

} catch {
print(error)
}
return nil
}
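A condensed sketch of the call site in AIChatModel.swift above, using optional binding instead of the force casts (all names are taken from this commit):

if let grammar_name = chat_config?["grammar"] as? String,
   grammar_name != "<None>", !grammar_name.isEmpty,
   let grammar_path = get_grammar_path_by_name(grammar_name) {
    // Hand the resolved .gbnf path to llama.cpp through the context params before loadModel.
    model_context_param.grammar_path = grammar_path
}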

public func get_grammars_list() -> [String]?{
var res: [String] = []
res.append("<None>")
do {
// var gbnf_path=Bundle.main.resourcePath!.appending("/grammars")
// let gbnf_files = try FileManager.default.contentsOfDirectory(atPath: gbnf_path)
// for gbnf_file in gbnf_files {
// let tmp_chat_info = ["file_name":gbnf_file,"location":"res"]
// res.append(tmp_chat_info)
// }
let fileManager = FileManager.default
let documentsPath = fileManager.urls(for: .documentDirectory, in: .userDomainMask).first
let destinationURL = documentsPath!.appendingPathComponent("grammars")
try fileManager.createDirectory (at: destinationURL, withIntermediateDirectories: true, attributes: nil)
let files = try fileManager.contentsOfDirectory(atPath: destinationURL.path)
for gbnf_file in files {
if gbnf_file.hasSuffix(".gbnf"){
// let tmp_chat_info = ["file_name":gbnf_file,"location":"doc"]
res.append(gbnf_file)
}
}
return res
} catch {
// failed to read directory – bad permissions, perhaps?
}
return res
}
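The names returned here back the grammar Picker that this commit adds to AddChatView.swift further down; only Documents/grammars is scanned, since the bundled-resources branch is left commented out. A minimal sketch of the consumer side (the file names in the comment are hypothetical):

let grammars_previews = get_grammars_list() ?? ["<None>"]   // e.g. ["<None>", "json.gbnf", "list.gbnf"]
// AddChatView keeps the selection in its `grammar` state and stores it under the "grammar"
// key of the chat config, which AIChatModel later resolves via get_grammar_path_by_name.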

//func get_config_by_model_name(_ model_name:String) -> Dictionary<String, AnyObject>?{
// do {
//// let index = model_name.index(model_name.startIndex, offsetBy:model_name.count-4)
28 changes: 26 additions & 2 deletions LLMFarm/Settings/AddChatView.swift
@@ -85,6 +85,7 @@ struct AddChatView: View {
@State private var isImporting: Bool = false
@State private var tfs_z: Float = 1.0
@State private var typical_p: Float = 1.0
@State private var grammar: String = "<None>"
var hardware_arch = Get_Machine_Hardware_Name()
@Binding var renew_chat_list: () -> Void

@@ -106,6 +107,9 @@ struct AddChatView: View {

@State var models_previews = get_models_list()!

@State var grammars_previews = get_grammars_list()!


init(add_chat_dialog: Binding<Bool>,edit_chat_dialog:Binding<Bool>,
renew_chat_list: Binding<() -> Void>) {
self._add_chat_dialog = add_chat_dialog
@@ -201,6 +205,9 @@ struct AddChatView: View {
if (chat_config!["typical_p"] != nil){
self._typical_p = State(initialValue: chat_config!["typical_p"] as! Float)
}
if (chat_config!["grammar"] != nil){
self._grammar = State(initialValue: chat_config!["grammar"]! as! String)
}
}

func apply_setting_template(template:ModelSettingsTemplate){
@@ -277,7 +284,8 @@ struct AddChatView: View {
"mirostat_eta":mirostat_eta,
"mirostat_tau":mirostat_tau,
"tfs_z":tfs_z,
"typical_p":typical_p
"typical_p":typical_p,
"grammar":grammar
]
_ = create_chat(options,edit_chat_dialog:self.edit_chat_dialog,chat_name:self.chat_name)
if add_chat_dialog {
@@ -412,6 +420,21 @@ struct AddChatView: View {
.padding(.horizontal)
.padding(.top, 8)

if model_inference == "llama"{
HStack{
Text("Grammar sampling:")
.frame(maxWidth: .infinity, alignment: .leading)
Picker("", selection: $grammar) {
ForEach(grammars_previews, id: \.self) {
Text($0)
}
}
.pickerStyle(.menu)

}
.padding(.horizontal)
.padding(.top, 8)
}

DisclosureGroup("Prompt format:", isExpanded: $isPromptAccordionExpanded) {
Group {
@@ -531,6 +554,7 @@ struct AddChatView: View {

DisclosureGroup("Sampling options:", isExpanded: $isSamplingAccordionExpanded) {
Group {

HStack{
Text("Sampling:")
.frame(maxWidth: .infinity, alignment: .leading)
@@ -559,7 +583,7 @@ struct AddChatView: View {
}
.padding(.horizontal)
.padding(.top, 8)

if model_sampling == "temperature" {
Group {

6 changes: 6 additions & 0 deletions LLMFarm/grammars/arithmetic.gbnf
@@ -0,0 +1,6 @@
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
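As a quick illustration, each root production here is expr "=" ws term "\n", so a generated line can look like:

1+1 = 2
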
42 changes: 42 additions & 0 deletions LLMFarm/grammars/c.gbnf
@@ -0,0 +1,42 @@
root ::= (declaration)*

declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"

dataType ::= "int" ws | "float" ws | "char" ws
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*

parameter ::= dataType identifier

statement ::=
( dataType identifier ws "=" ws expression ";" ) |
( identifier ws "=" ws expression ";" ) |
( identifier ws "(" argList? ")" ";" ) |
( "return" ws expression ";" ) |
( "while" "(" condition ")" "{" statement* "}" ) |
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
( singleLineComment ) |
( multiLineComment )

forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
forUpdate ::= identifier ws "=" ws expression

condition ::= expression relationOperator expression
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")

expression ::= term (("+" | "-") term)*
term ::= factor(("*" | "/") factor)*

factor ::= identifier | number | unaryTerm | funcCall | parenExpression
unaryTerm ::= "-" factor
funcCall ::= identifier "(" argList? ")"
parenExpression ::= "(" ws expression ws ")"

argList ::= expression ("," ws expression)*

number ::= [0-9]+

singleLineComment ::= "//" [^\n]* "\n"
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"

ws ::= ([ \t\n]+)
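Note that ws in this grammar requires at least one whitespace character, so a data type must be followed by a space; a minimal function the grammar derives:

int main(){return 0;}
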
13 changes: 13 additions & 0 deletions LLMFarm/grammars/chess.gbnf
@@ -0,0 +1,13 @@
# Specifies chess moves as a list in algebraic notation, using PGN conventions

# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern
root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+
move ::= (pawn | nonpawn | castle) [+#]?

# piece type, optional file/rank, optional capture, dest file & rank
nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8]

# optional file & capture, dest file & rank, optional promotion
pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])?

castle ::= "O-O" "-O"?
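The root rule demands the literal "1. " and at least one numbered follow-up line, so the shortest accepted output has two move pairs, for example:

1. e4 e5
2. Nf3 Nc6
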
7 changes: 7 additions & 0 deletions LLMFarm/grammars/japanese.gbnf
@@ -0,0 +1,7 @@
# A probably incorrect grammar for Japanese
root ::= jp-char+ ([ \t\n] jp-char+)*
jp-char ::= hiragana | katakana | punctuation | cjk
hiragana ::= [ぁ-ゟ]
katakana ::= [ァ-ヿ]
punctuation ::= [、-〾]
cjk ::= [一-鿿]
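A string accepted by this grammar (hiragana, a punctuation mark, then two CJK characters):

こんにちは、世界
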
25 changes: 25 additions & 0 deletions LLMFarm/grammars/json.gbnf
@@ -0,0 +1,25 @@
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws

array ::=
"[" ws (
value
("," ws value)*
)? "]" ws

string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
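Because root is restricted to object, a completion constrained by this grammar is always a single JSON object, for example:

{"name": "LLMFarm", "size": 3}
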
34 changes: 34 additions & 0 deletions LLMFarm/grammars/json_arr.gbnf
@@ -0,0 +1,34 @@
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays

root ::= arr
value ::= object | array | string | number | ("true" | "false" | "null") ws

arr ::=
"[\n" ws (
value
(",\n" ws value)*
)? "]"

object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws

array ::=
"[" ws (
value
("," ws value)*
)? "]" ws

string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
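Here root forces the array form: the output must open with "[" plus a newline and must end at "]" with no trailing whitespace, e.g.:

[
"apple",
"banana"
]
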
4 changes: 4 additions & 0 deletions LLMFarm/grammars/list.gbnf
@@ -0,0 +1,4 @@
root ::= item+

# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
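Every item has to start with "- " and end with a newline, so constrained output takes the shape of a plain bullet list:

- milk
- eggs
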
67 changes: 25 additions & 42 deletions ModelTest/main.swift
@@ -43,8 +43,8 @@ func main(){
// modelInference = ModelInference.GPTNeox
////
//
// ai.modelPath = "/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/rp-incite-base-v1-3b-ggmlv3-q5_1.bin"
// modelInference = ModelInference.GPTNeox
// ai.modelPath = "/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/rp-incite-base-v1-3b-ggmlv3-q5_1.bin"
// modelInference = ModelInference.GPTNeox
//
// ai.modelPath = "/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/magicprompt-stable-diffusion-q5_1.bin"
// ai.modelPath = "/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/cerebras-2.7b-ggjtv3-q4_0.bin"
@@ -57,21 +57,24 @@ func main(){
// modelInference = ModelInference.Starcoder
// input_text = "def qsort"
//
// ai.modelPath = "/Users/guinmoon/dev/alpaca_llama_etc/q4_1-RWKV-4-Raven-1B5-v12-Eng.bin"
// ai.modelPath = "/Users/guinmoon/dev/alpaca_llama_etc/RWKV-4-MIDI-120M-v1-20230714-ctx4096-FP16.bin"
// ai.modelPath = "/Users/guinmoon/dev/alpaca_llama_etc/Sources/rwkv.cpp-master-8db73b1/tests/tiny-rwkv-660K-FP16.bin"
// modelInference = ModelInference.RWKV
// input_text = "song about love"
ai.modelPath = "/Users/guinmoon/dev/alpaca_llama_etc/q4_1-RWKV-4-Raven-1B5-v12-Eng.bin"
// // ai.modelPath = "/Users/guinmoon/dev/alpaca_llama_etc/RWKV-4-MIDI-120M-v1-20230714-ctx4096-FP16.bin"
// // ai.modelPath = "/Users/guinmoon/dev/alpaca_llama_etc/Sources/rwkv.cpp-master-8db73b1/tests/tiny-rwkv-660K-FP16.bin"
modelInference = ModelInference.RWKV
// input_text = "song about love"

// ai.modelPath = "/Users/guinmoon/dev/alpaca_llama_etc/orca-mini-3b.ggmlv3.q4_1.bin"
// ai.modelPath = "/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/llama-2-7b-chat-q4_K_M.gguf"
// ai.modelPath = "/Users/guinmoon/dev/alpaca_llama_etc/openllama-3b-v2-q8_0.gguf"
ai.modelPath = "/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/orca-mini-3b-q4_1.gguf"
modelInference = ModelInference.LLama_gguf


// ai.modelPath = "/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/orca-mini-3b-q4_1.gguf"
// modelInference = ModelInference.LLama_gguf
//
var params:ModelContextParams = .default
params.use_metal = true
//
// params.use_metal = true

// params.grammar_path = "/Users/guinmoon/dev/alpaca_llama_etc/LLMFarm/LLMFarm/grammars/list.gbnf"
input_text = "write to do list"

do{
try ai.loadModel(modelInference,contextParams: params)
Expand All @@ -80,28 +83,8 @@ func main(){
return
}

//// try? set_promt_format(ai: &ai)
// let exception = tryBlock {
//
//// try? ai.model.promptFormat = .LLaMa
//
// }
//
// if exception != nil {
// print(exception)
// exit(1)
// }
//
//

// ai.model.promptFormat = .Custom
// ai.model.custom_prompt_format = "Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:{{prompt}}### Response:"
////


// ai.model.contextParams.seed = 0;
// ai.model.promptStyle = .StableLM_Tuned


// if (!params.lora_adapter.empty()) {
// int err = llama_model_apply_lora_from_file(model,
Expand All @@ -126,20 +109,20 @@ func main(){
//### Response:
//"""
//
input_text = """
### User:
Tell more
### Response:
"""
// input_text = """
//### User:
//Tell more
//
//### Response:
//"""
// var tokens: [llama_token] = [Int32](repeating: 0, count: 256)
// var tokens_count:Int = 1
llama_load_state(ai.model.context,"/Users/guinmoon/dev/alpaca_llama_etc/dump_state_.bin")
// llama_load_state(ai.model.context,"/Users/guinmoon/dev/alpaca_llama_etc/dump_state_.bin")
// llama_load_session_file(ai.model.context,"/Users/guinmoon/dev/alpaca_llama_etc/dump_state.bin",tokens.mutPtr, 256,&tokens_count)
let prompt = input_text
let output = try? ai.model.predict(prompt, mainCallback)

let output = try? ai.model.predict(input_text, mainCallback)
// llama_save_session_file(ai.model.context,"/Users/guinmoon/dev/alpaca_llama_etc/dump_state.bin",ai.model.session_tokens, ai.model.session_tokens.count)
llama_save_state(ai.model.context,"/Users/guinmoon/dev/alpaca_llama_etc/dump_state_.bin")
// llama_save_state(ai.model.context,"/Users/guinmoon/dev/alpaca_llama_etc/dump_state_.bin")
//
print(output!)
}
