refactor: minor cleanups

This commit is contained in:
Youwen Wu 2025-01-07 20:56:15 -08:00
parent fcb53c9462
commit 12cec7f4e6
Signed by: youwen5
GPG key ID: 865658ED1FE61EC3
2 changed files with 33 additions and 30 deletions

View file

@ -1,41 +1,21 @@
use std::env;
mod model;
use std::env::args;
use std::path::Path;
use fasttext::{Args, FastText, ModelName};
fn train_model(out: &Path, dataset: &Path) -> Result<(), String> {
let mut args = Args::new();
args.set_input(dataset.to_str().unwrap()).unwrap();
args.set_model(ModelName::SUP);
args.set_lr(1.0);
args.set_epoch(25);
//args.set_loss(LossName::SOFTMAX);
let mut ft_model = FastText::new();
ft_model.train(&args).unwrap();
ft_model.save_model(out.to_str().unwrap())
}
fn test_model(filename: &Path, query: &str) -> Vec<fasttext::Prediction> {
let mut text = FastText::new();
let _ = text.load_model(filename.to_str().unwrap());
text.predict(query, 3, 0.3).unwrap()
}
fn cli() {
let pattern = env::args().nth(1).expect("no command given!");
let pattern = args().nth(1).expect("no command given!");
match pattern.as_str() {
"test" => {
let model_path = env::args().nth(2).expect("no path to model given");
let query = env::args().nth(3).expect("no query to run");
let model_path = args().nth(2).expect("no path to model given");
let query = args().nth(3).expect("no query to run");
println!("{:?}", test_model(Path::new(&model_path), &query));
println!("{:?}", model::query(Path::new(&model_path), &query));
}
"train" => {
let out_path = env::args().nth(2).expect("no path to model given");
let dataset_path = env::args().nth(3).expect("no path to dataset given");
train_model(Path::new(&out_path), Path::new(&dataset_path)).unwrap();
let out_path = args().nth(2).expect("no path to model given");
let dataset_path = args().nth(3).expect("no path to dataset given");
model::train(Path::new(&out_path), Path::new(&dataset_path)).unwrap();
}
_ => panic!("subcommand does not exist"),
}

23
src/model.rs Normal file
View file

@ -0,0 +1,23 @@
use std::path::Path;
use fasttext::{Args, FastText, ModelName};
pub fn train(out: &Path, dataset: &Path) -> Result<(), String> {
let mut args = Args::new();
args.set_input(dataset.to_str().unwrap()).unwrap();
args.set_model(ModelName::SUP);
args.set_lr(1.0);
args.set_epoch(25);
//args.set_loss(LossName::SOFTMAX);
let mut ft_model = FastText::new();
ft_model.train(&args).unwrap();
ft_model.save_model(out.to_str().unwrap())
}
pub fn query(filename: &Path, query: &str) -> Vec<fasttext::Prediction> {
let mut text = FastText::new();
let _ = text.load_model(filename.to_str().unwrap());
text.predict(query, 3, 0.3).unwrap()
}