refactor: minor cleanups
This commit is contained in:
parent
fcb53c9462
commit
12cec7f4e6
2 changed files with 33 additions and 30 deletions
40
src/main.rs
40
src/main.rs
|
@ -1,41 +1,21 @@
|
|||
use std::env;
|
||||
mod model;
|
||||
|
||||
use std::env::args;
|
||||
use std::path::Path;
|
||||
|
||||
use fasttext::{Args, FastText, ModelName};
|
||||
|
||||
fn train_model(out: &Path, dataset: &Path) -> Result<(), String> {
|
||||
let mut args = Args::new();
|
||||
args.set_input(dataset.to_str().unwrap()).unwrap();
|
||||
args.set_model(ModelName::SUP);
|
||||
args.set_lr(1.0);
|
||||
args.set_epoch(25);
|
||||
//args.set_loss(LossName::SOFTMAX);
|
||||
let mut ft_model = FastText::new();
|
||||
ft_model.train(&args).unwrap();
|
||||
|
||||
ft_model.save_model(out.to_str().unwrap())
|
||||
}
|
||||
|
||||
fn test_model(filename: &Path, query: &str) -> Vec<fasttext::Prediction> {
|
||||
let mut text = FastText::new();
|
||||
|
||||
let _ = text.load_model(filename.to_str().unwrap());
|
||||
text.predict(query, 3, 0.3).unwrap()
|
||||
}
|
||||
|
||||
fn cli() {
|
||||
let pattern = env::args().nth(1).expect("no command given!");
|
||||
let pattern = args().nth(1).expect("no command given!");
|
||||
match pattern.as_str() {
|
||||
"test" => {
|
||||
let model_path = env::args().nth(2).expect("no path to model given");
|
||||
let query = env::args().nth(3).expect("no query to run");
|
||||
let model_path = args().nth(2).expect("no path to model given");
|
||||
let query = args().nth(3).expect("no query to run");
|
||||
|
||||
println!("{:?}", test_model(Path::new(&model_path), &query));
|
||||
println!("{:?}", model::query(Path::new(&model_path), &query));
|
||||
}
|
||||
"train" => {
|
||||
let out_path = env::args().nth(2).expect("no path to model given");
|
||||
let dataset_path = env::args().nth(3).expect("no path to dataset given");
|
||||
train_model(Path::new(&out_path), Path::new(&dataset_path)).unwrap();
|
||||
let out_path = args().nth(2).expect("no path to model given");
|
||||
let dataset_path = args().nth(3).expect("no path to dataset given");
|
||||
model::train(Path::new(&out_path), Path::new(&dataset_path)).unwrap();
|
||||
}
|
||||
_ => panic!("subcommand does not exist"),
|
||||
}
|
||||
|
|
23
src/model.rs
Normal file
23
src/model.rs
Normal file
|
@ -0,0 +1,23 @@
|
|||
use std::path::Path;
|
||||
|
||||
use fasttext::{Args, FastText, ModelName};
|
||||
|
||||
pub fn train(out: &Path, dataset: &Path) -> Result<(), String> {
|
||||
let mut args = Args::new();
|
||||
args.set_input(dataset.to_str().unwrap()).unwrap();
|
||||
args.set_model(ModelName::SUP);
|
||||
args.set_lr(1.0);
|
||||
args.set_epoch(25);
|
||||
//args.set_loss(LossName::SOFTMAX);
|
||||
let mut ft_model = FastText::new();
|
||||
ft_model.train(&args).unwrap();
|
||||
|
||||
ft_model.save_model(out.to_str().unwrap())
|
||||
}
|
||||
|
||||
pub fn query(filename: &Path, query: &str) -> Vec<fasttext::Prediction> {
|
||||
let mut text = FastText::new();
|
||||
|
||||
let _ = text.load_model(filename.to_str().unwrap());
|
||||
text.predict(query, 3, 0.3).unwrap()
|
||||
}
|
Loading…
Reference in a new issue