feat: add tokenization to model
This commit is contained in:
parent
12cec7f4e6
commit
3970073352
9 changed files with 4155 additions and 4 deletions
383
Cargo.lock
generated
383
Cargo.lock
generated
|
@ -2,6 +2,27 @@
|
||||||
# It is not intended for manual editing.
|
# It is not intended for manual editing.
|
||||||
version = 4
|
version = 4
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ahash"
|
||||||
|
version = "0.3.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aho-corasick"
|
||||||
|
version = "1.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "autocfg"
|
||||||
|
version = "1.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.2.7"
|
version = "1.2.7"
|
||||||
|
@ -20,6 +41,58 @@ dependencies = [
|
||||||
"cc",
|
"cc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-deque"
|
||||||
|
version = "0.8.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-epoch"
|
||||||
|
version = "0.9.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.8.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "csv"
|
||||||
|
version = "1.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf"
|
||||||
|
dependencies = [
|
||||||
|
"csv-core",
|
||||||
|
"itoa",
|
||||||
|
"ryu",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "csv-core"
|
||||||
|
version = "0.1.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "either"
|
||||||
|
version = "1.13.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fasttext"
|
name = "fasttext"
|
||||||
version = "0.7.8"
|
version = "0.7.8"
|
||||||
|
@ -29,11 +102,239 @@ dependencies = [
|
||||||
"cfasttext-sys",
|
"cfasttext-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.7.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "96282e96bfcd3da0d3aa9938bedf1e50df3269b6db08b4876d2da0bb1a0841cf"
|
||||||
|
dependencies = [
|
||||||
|
"ahash",
|
||||||
|
"autocfg",
|
||||||
|
"rayon",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itertools"
|
||||||
|
version = "0.8.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f56a2d0bc861f9165be4eb3442afd3c236d8a98afd426f65d92324ae1091a484"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itertools"
|
||||||
|
version = "0.14.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itoa"
|
||||||
|
version = "1.0.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lab-intent-classifier"
|
name = "lab-intent-classifier"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"csv",
|
||||||
"fasttext",
|
"fasttext",
|
||||||
|
"itertools 0.14.0",
|
||||||
|
"rust-stemmers",
|
||||||
|
"serde",
|
||||||
|
"serde_derive",
|
||||||
|
"stopwords",
|
||||||
|
"vtext",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lazy_static"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matrixmultiply"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "916806ba0031cd542105d916a97c8572e1fa6dd79c9c51e7eb43a09ec2dd84c1"
|
||||||
|
dependencies = [
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memchr"
|
||||||
|
version = "2.7.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ndarray"
|
||||||
|
version = "0.13.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ac06db03ec2f46ee0ecdca1a1c34a99c0d188a0d83439b84bf0cb4b386e4ab09"
|
||||||
|
dependencies = [
|
||||||
|
"matrixmultiply",
|
||||||
|
"num-complex",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits 0.2.19",
|
||||||
|
"rawpointer",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-traits 0.2.19",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-integer"
|
||||||
|
version = "0.1.46"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits 0.2.19",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.1.43"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "92e5113e9fd4cc14ded8e499429f396a20f98c772a47cc8622a736e1ec843c31"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits 0.2.19",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.92"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quote"
|
||||||
|
version = "1.0.38"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rawpointer"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon"
|
||||||
|
version = "1.10.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
"rayon-core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rayon-core"
|
||||||
|
version = "1.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-deque",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex"
|
||||||
|
version = "1.11.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-automata",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-automata"
|
||||||
|
version = "0.4.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-syntax"
|
||||||
|
version = "0.8.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rust-stemmers"
|
||||||
|
version = "1.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e46a2036019fdb888131db7a4c847a1063a7493f971ed94ea82c67eada63ca54"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
"serde_derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ryu"
|
||||||
|
version = "1.0.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "seahash"
|
||||||
|
version = "4.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde"
|
||||||
|
version = "1.0.217"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70"
|
||||||
|
dependencies = [
|
||||||
|
"serde_derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_derive"
|
||||||
|
version = "1.0.217"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -41,3 +342,85 @@ name = "shlex"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sprs"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ec63571489873d4506683915840eeb1bb16b3198ee4894cc6f2fe3013d505e56"
|
||||||
|
dependencies = [
|
||||||
|
"ndarray",
|
||||||
|
"num-complex",
|
||||||
|
"num-traits 0.1.43",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "stopwords"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0e4508a6e132e6ea159112d42ed1f29927460dd45a118eed298d7666c81b713e"
|
||||||
|
dependencies = [
|
||||||
|
"lazy_static",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "2.0.95"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror"
|
||||||
|
version = "1.0.69"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror-impl"
|
||||||
|
version = "1.0.69"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-ident"
|
||||||
|
version = "1.0.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-segmentation"
|
||||||
|
version = "1.12.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "vtext"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7ec02229b562eef118ef9cba5ca50c406ccce6138bdc6b49b767b77fab648764"
|
||||||
|
dependencies = [
|
||||||
|
"hashbrown",
|
||||||
|
"itertools 0.8.2",
|
||||||
|
"lazy_static",
|
||||||
|
"ndarray",
|
||||||
|
"regex",
|
||||||
|
"seahash",
|
||||||
|
"serde",
|
||||||
|
"sprs",
|
||||||
|
"thiserror",
|
||||||
|
"unicode-segmentation",
|
||||||
|
]
|
||||||
|
|
|
@ -5,3 +5,10 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
fasttext = "0.7.8"
|
fasttext = "0.7.8"
|
||||||
|
vtext = "0.2.0"
|
||||||
|
rust-stemmers = "=1.2.0"
|
||||||
|
stopwords = "0.1.1"
|
||||||
|
itertools = "0.14.0"
|
||||||
|
csv = "1.3.1"
|
||||||
|
serde = "1.0.217"
|
||||||
|
serde_derive = "1.0.217"
|
||||||
|
|
1818
data/dataset.csv
Normal file
1818
data/dataset.csv
Normal file
File diff suppressed because it is too large
Load diff
1817
data/tokenized.txt
Normal file
1817
data/tokenized.txt
Normal file
File diff suppressed because it is too large
Load diff
44
src/data.rs
Normal file
44
src/data.rs
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
use crate::nlp;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io;
|
||||||
|
use std::io::Write;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug)]
|
||||||
|
struct FasttextPair {
|
||||||
|
label: String,
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn process_csv(input_file: &Path, output_file: &Path) -> Result<(), std::io::Error> {
|
||||||
|
let file = File::open(input_file)?;
|
||||||
|
let mut rdr = csv::Reader::from_reader(file);
|
||||||
|
|
||||||
|
let mut data: Vec<FasttextPair> = Vec::new();
|
||||||
|
|
||||||
|
for result in rdr.deserialize() {
|
||||||
|
let mut r: FasttextPair = result.expect("bruh");
|
||||||
|
r.text = nlp::tokenize(&r.text);
|
||||||
|
data.push(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
let formatted_output: Vec<String> = data
|
||||||
|
.into_iter()
|
||||||
|
.map(|pair| format!("{} {}", pair.label, pair.text))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
write_to_file(
|
||||||
|
output_file.to_str().expect("could not write output file"),
|
||||||
|
&formatted_output,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Writes every entry of `lines` to `filename`, each followed by a newline.
///
/// The file is opened with `File::create`, so it is created if missing and
/// truncated if it already exists. An empty `lines` slice yields an empty
/// file.
///
/// # Errors
///
/// Returns the underlying `io::Error` if the file cannot be created or if
/// any write (including the final flush) fails.
pub fn write_to_file(filename: &str, lines: &[String]) -> io::Result<()> {
    // Buffer the output so each line does not become its own write call.
    let mut writer = io::BufWriter::new(File::create(filename)?);
    for line in lines {
        writeln!(writer, "{}", line)?;
    }
    // Flush explicitly: an error raised while a BufWriter is dropped is lost.
    writer.flush()
}
|
14
src/main.rs
14
src/main.rs
|
@ -1,4 +1,9 @@
|
||||||
|
#[macro_use]
|
||||||
|
extern crate serde_derive;
|
||||||
|
|
||||||
|
mod data;
|
||||||
mod model;
|
mod model;
|
||||||
|
mod nlp;
|
||||||
|
|
||||||
use std::env::args;
|
use std::env::args;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
@ -17,6 +22,15 @@ fn cli() {
|
||||||
let dataset_path = args().nth(3).expect("no path to dataset given");
|
let dataset_path = args().nth(3).expect("no path to dataset given");
|
||||||
model::train(Path::new(&out_path), Path::new(&dataset_path)).unwrap();
|
model::train(Path::new(&out_path), Path::new(&dataset_path)).unwrap();
|
||||||
}
|
}
|
||||||
|
"tokenize" => {
|
||||||
|
let input_path_raw = args().nth(2).expect("no input path");
|
||||||
|
let output_path_raw = args().nth(3).expect("no output path");
|
||||||
|
|
||||||
|
let input_path = Path::new(&input_path_raw);
|
||||||
|
let output_path = Path::new(&output_path_raw);
|
||||||
|
|
||||||
|
data::process_csv(input_path, output_path).unwrap();
|
||||||
|
}
|
||||||
_ => panic!("subcommand does not exist"),
|
_ => panic!("subcommand does not exist"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
11
src/model.rs
11
src/model.rs
|
@ -1,3 +1,5 @@
|
||||||
|
extern crate serde;
|
||||||
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use fasttext::{Args, FastText, ModelName};
|
use fasttext::{Args, FastText, ModelName};
|
||||||
|
@ -15,9 +17,10 @@ pub fn train(out: &Path, dataset: &Path) -> Result<(), String> {
|
||||||
ft_model.save_model(out.to_str().unwrap())
|
ft_model.save_model(out.to_str().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn query(filename: &Path, query: &str) -> Vec<fasttext::Prediction> {
|
pub fn tokenize_and_query(filename: &Path, query: &str) -> Vec<fasttext::Prediction> {
|
||||||
let mut text = FastText::new();
|
let mut ft_model = FastText::new();
|
||||||
|
let query_tokenized = crate::nlp::tokenize(query);
|
||||||
|
|
||||||
let _ = text.load_model(filename.to_str().unwrap());
|
let _ = ft_model.load_model(filename.to_str().unwrap());
|
||||||
text.predict(query, 3, 0.3).unwrap()
|
ft_model.predict(&query_tokenized, 3, 0.7).unwrap()
|
||||||
}
|
}
|
||||||
|
|
32
src/nlp.rs
Normal file
32
src/nlp.rs
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
use itertools::Itertools;
|
||||||
|
use rust_stemmers::{Algorithm, Stemmer};
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use stopwords::{Language, Spark, Stopwords};
|
||||||
|
use vtext::tokenize::{Tokenizer, VTextTokenizer};
|
||||||
|
|
||||||
|
pub fn tokenize(text: &str) -> String {
|
||||||
|
// convert all to lowercase
|
||||||
|
let lc_text = text.to_lowercase();
|
||||||
|
|
||||||
|
// tokenise the words
|
||||||
|
let tok = VTextTokenizer::default();
|
||||||
|
let tokens: Vec<&str> = tok.tokenize(lc_text.as_str()).collect();
|
||||||
|
|
||||||
|
// stem the words
|
||||||
|
let en_stemmer = Stemmer::create(Algorithm::English);
|
||||||
|
let tokens: Vec<String> = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|x| en_stemmer.stem(x).into_owned())
|
||||||
|
.collect();
|
||||||
|
let mut tokens: Vec<&str> = tokens.iter().map(|x| x.as_str()).collect();
|
||||||
|
|
||||||
|
// remove the stopwords
|
||||||
|
let stops: HashSet<_> = Spark::stopwords(Language::English)
|
||||||
|
.unwrap()
|
||||||
|
.iter()
|
||||||
|
.collect();
|
||||||
|
tokens.retain(|s| !stops.contains(s));
|
||||||
|
|
||||||
|
// join the tokens and return
|
||||||
|
tokens.iter().join(" ")
|
||||||
|
}
|
33
utils/convert_to_csv.js
Normal file
33
utils/convert_to_csv.js
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
const fs = require("fs");
|
||||||
|
const path = require("path");
|
||||||
|
|
||||||
|
// Input and output file paths.
// This script lives in utils/, while the dataset directory sits at the
// repository root, so step up one level before descending into data/.
const inputFilePath = path.join(__dirname, "..", "data", "gpt-dataset.txt");
const outputFilePath = path.join(__dirname, "..", "data", "dataset.csv");
|
||||||
|
|
||||||
|
/**
 * Convert a fastText-style labelled text file into a two-column CSV.
 *
 * Every occurrence of `__label__<name> <text>` in the input becomes one CSV
 * row under the header `Label,Text`. The label is written without its two
 * leading underscores (`label__<name>`); the text runs to the end of the
 * line and is always double-quoted, with embedded quotes doubled.
 *
 * Input lines that carry a label but no text after it are skipped.
 *
 * Failures are logged to stderr and reported through a non-zero process
 * exit code; the function itself does not throw.
 *
 * @param {string} inputFile - Path of the labelled text file to read (UTF-8).
 * @param {string} outputFile - Path of the CSV file to write.
 * @returns {void}
 */
function convertToCSV(inputFile, outputFile) {
  // Quote the label only when it contains a character that would otherwise
  // split or corrupt the CSV row, so ordinary labels are written unchanged.
  const escapeLabel = (label) =>
    /[",]/.test(label) ? `"${label.replace(/"/g, '""')}"` : label;
  const quoteText = (text) => `"${text.replace(/"/g, '""')}"`;

  try {
    const data = fs.readFileSync(inputFile, "utf8");

    // `\S+` rather than `[^ ]+`: the latter also matches newlines, which
    // glued a label-only line onto the following line as one bogus label.
    const rows = [...data.matchAll(/__(label__\S+) (.+)/g)].map(
      ([, label, text]) => `${escapeLabel(label)},${quoteText(text)}\n`,
    );

    fs.writeFileSync(outputFile, `Label,Text\n${rows.join("")}`);
    console.log(`CSV file has been saved at ${outputFile}`);
  } catch (error) {
    console.error("Error:", error);
    // Let the calling shell or CI step see that the conversion failed.
    process.exitCode = 1;
  }
}
|
||||||
|
|
||||||
|
// Run the function
|
||||||
|
convertToCSV(inputFilePath, outputFilePath);
|
Loading…
Reference in a new issue