From 9a486a52a027adafefa27b1d7ef6b62df3d873c8 Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Thu, 2 Dec 2021 14:06:31 +0000 Subject: [PATCH] update multiwoz21 preprocess and add dbquery --- data/unified_datasets/multiwoz21/data.zip | Bin 12314671 -> 12314614 bytes data/unified_datasets/multiwoz21/database.py | 107 ++++++++++++++++++ .../unified_datasets/multiwoz21/preprocess.py | 9 +- 3 files changed, 112 insertions(+), 4 deletions(-) create mode 100644 data/unified_datasets/multiwoz21/database.py diff --git a/data/unified_datasets/multiwoz21/data.zip b/data/unified_datasets/multiwoz21/data.zip index 51ac6b5f326b2c9c6d078c433c89b9e0010301cd..e320219217e24c6cdef07de99ab0cf6af88c7c59 100644 GIT binary patch delta 3371 zcmZ2~;XUiO^Y2-C1H748L>L$tI2aTPnl`^>xxmisUeGkX?#b>+%pL_z+dcGl>#BfQ z(}kbz-o*`5>|W5cy&`D0Unxjx`^n?Gw+nN+7BmG*H2-UyG`;ioZu$Dtc~3M%?tTwn zkdr3d`)udTc*(6wr<Sb>$a%~dGj+w)|90Yk4w<ptW>pFLt~A$V-!qv9Me;U(-f}zn zTWTAwJ#yLph4deBt_@RU(kE@o(g;p+Nttsk_T%N5VU~wPjauec2L>8w_LPUto%6u+ z<R&-OCKczVgn6+SXBI9lGQ6@bzJBQ=)x(owB$i9+J;*&NZ8UMkOo1;U6Q)b+NiDkZ zC)<Xtu_t8GWY4VuUK=KS-qclCIai$Pnhy7SmV0Jh`}peX<yGfv#R&PwYX}}W=MmSr zw*75w!Lim4WleJ@)vT|bGUY^`*Y9@MUfZcElg}0v`0?nosV^4Q{J3a}DC<rgtrHRT zQw1e_Vnn(d*Q?kxb_N`Z|J<Ic>=CTz&F%TRNg=Sgt)unUh0~#FuEn$LR~XjEuDCpR zn(A6T(`4?f3y&jKt_t5c#Vt8Q(MtG!M%+!=sk}bMT3<T+;(C{*tx*%*wW*Dpe{Z0{ zQ=J2q=Q=+WzFra}w0ixTPrWOqIfe$$<M2P#ThBRj;~nqy!rq>`yNza5`o6u8+_s&u zD`L?kft4qVKE*p*M$TwEezVT7zWh#Nkyp@<+WbDhp1n#vQI*qwFFmy@)TQdzqdL_K zG3#!x(RK~jx$vPVBf9=sO5!4IGyTf;@V(3YZ-yAmZ@+bY|L@M(+f8elW~|gJFgvqp zWm-$tjpMELeojg2-v=0VBy&1TZI`Z_z5U_$%7gNso9v&n)}Q76cW=RK2gw|vb$W4& zSF8H_)-gJZZPh;^U6JFMe|iFU)ATAq%SPn}!CL~CayWe6p9(2Y?2%lOm-#}i*!F2% zV-V-U&IpsOtUR*+*t6T__LeD{CJG$SE1R-*$GHn`-mA1e)kj}+pLys>-S-bV9;Fk* zj|Ou#hx=da-8o&_!{=CO^7^;)`&Czq%&t^p+%)Zx^-1m5%$2PicY^AkT?x`MP|4l) z$*uU~?XAs$c~2JFN<45&E7b_^JkvROb34C>fM3YGOvyt()Kb!&7DUz-=qY@ZE?vTO zVWtCj%9T@+OZ3hCl^E*RJ14ZHcUNzl#p`M7)#bjW@gyh9!>W_V-B(R^W<4!bsLk<3 zYlq<a+Xu}>i@a|Ae79nW!E=F}g_F-e%&?46Nh-b4bUbydV%r&oRD<N>94%`$OZRL( z!rSuclHu{^FZX?%l)XycNBGx@b4^L_JD7hi)>V%<tFTR-+kNws8NPovgw+2$6Y=gP z&(&{h7B+?b-u<ZY=MG~=FUOYwCR3Yk9DHERxwZM29qV?6$hRKGfmy3#w&WiQy1m8Y zNPc+P?%8+yCq9xpKOu<k`#F=$CEqUY6gqyTK8-PR<BOk@x|B@kpP$ZrWcjxQJMsAs zCN5d`ooD0kbq{BMSu1R8_2!$Wc$ana)w248C+BWIUb3y==gaAL>Q<UAyBIskFVKAH zM&Y<gY3uLxzILA`X2SJVHOl0E?lmWiG~RCq?RKUeo${$GIPd9|X}xAyr$S;p1Fh$N zKIc-z_3`ASlOa#|)TW$#A?u^H_|s;sZ&N#GdA+e;-duGu*XwSZOI20XIm;_5i_>Q< z{McJ)y|n)Ex1A+6c55F;cU+t)d)Kwc>CCE&dxhh}UwA6C-+!@r;&0pSw^Prry<X|& zlD2=BXs!Q>Z^eQKIjp|viY$$kt;*T>_PCwRCgZ-qT*aMMn|z;E+k_o2Zn9XMxiwNe zNbp_Bp}8zs>vt}&4sGAAlEr)Jl3nT5`j-k{PTo9a5F5bYRUd7|5Y2d3aF^__jE%>t zX9a{e?znX_ay@7F^~5%d)vIQ_`jk6IY9arDvRQYdOSBSfG8{j#9-mYg{$rg5-&L(m zUUH7N7`N@%y>oMcksN<Bi+1N8zP@FvYs1d(?w;DPvpP|SU32e+>%t+sSFqH5IVZk4 zucg~8Va^ray}>&@UOi8$xB0)!X_qH!c~Hul62Y4Ev-%hO_dW<&ZmFoZRUv<C->NNh zW~FadY?e}#I3sJyW3<3G*Dv$ir7d=jT$6UbxH9LihKo+qoi8rg+9G`$lhZdy%el^Y zA!_TazB)IWale*Rk>L%A>Z3=l{+d@{t$C|Pc>a_?gWl^oO)ppRpD{3%WT{^@qif16 zrZ;>KI*nWZT}*N_yXQaU;FmcWtv`Qz`xPG4<E=b6Q$~44qkHy(UptMXwp_a>8F>53 zjg{=lo!j(woygh7`*`jRg(_nirbVIuPbkb4-M(N=Y=#fl72PjOnl|72$HgcW^z~Yx zm}ijAjH9NrLnH*hB(i1}-c(n*_UDl=NBxSty@u~J*3Lb^#b;vKF2<R(N$D?#n|Q06 z8S`H4zWiPMO0x4-^X(~EmTvEH=$lPK&9`RR?RI=?4orKf5Y(f{ovwcDgFP>g3UgrN zu^#PxzC}r`n}ysH<Yw$~f6(V2uO9V)^N#)k7r*phTbIn`y7B(e66O%~Et418&bTJy za&>k6&XX5z$*wOc_-hts<WgI$tZ^&LBr7jP#pTLvq0A4TQaq|AYiV4mnkIj=XYFB0 z>o1)G%05Nx5mPkxJvyq)J6BeF$LCH#&bO@G@AH&cVs>v#(%2)bxpdp?@J~SxR_fjn zv3Q#v8KltK&vs3WXK}<b(+wQI_Hl{2K8(!_o-(h5sh-XJi=IjCMI*_misJW9-3zAs z?Rc_F@~qNshvwO-zs(L-*lRG}sCcZ-d80yGoAFLTqy}SHi)1=G*L%<Ku#n`H3uSeq z`1Uu6+boVd+-b7^+Jz+F4WfJgor~t$;BomT-?g|HuSSu3t8WyxS^NJ!xNCEY*qQ^k zRbD$h{N_=XY_p+WXM^<XicS0YL@TmWWHs(wi>aP;nMttd^HTdGl5A}M{;rZ+wIr&~ z-6MS8sr$9JPpnvP?UOg9Wd6~a+me4yOg!H0eDcT4*v$o>7Aih={J%A9@rfS-8v}Rc zv{uH|$IXfMZs+FMTz1QT<CR+3bZ_l#U%G!X_8)(7ZvXeRm=N<-GwSrFC45TQQ(>QI zQ>T@{XYoJdP~4Hq&&wk0(!QPhv8Zm}e||*M$F-oT`2bJ*0UpNo13XOa2Y8s<5Ad+G zAK+naKfuG*et?I){QwU~`vD%#_5(az?FV?c+Yj)7)bq9<;NfdOz{B5ufJdPH0FPk% z0Un|D13bd*2Y5u<5AcY#AK(#dKfoj2et<`!{Q!?-`vD%Q_5(c9?FV>d+7Ix^wjbb; zYd^pv-+q8cq5S}lV*3FerS=0n%IybuRN4>lsJ0*AQENZIquzdiN2C1!k7oM;9<BBR zJlgFCcy!ti@aVQ5;L&S8z@y)OfXATy0FPn&0Uo3F13bp<2Y5`{5Ac|_AK)=-Kfq(& zet^fK{Q!?;`vD%S_5(cD?FV>l+7Iy9wjbcJYd^qa-+q9{q5S}lWBUOfr}hIp&g}<y zT-p!txV9hQace)o<KBLN$D{oKk7xS<9<TNTJl^dGczoIq@c6bL;PGoez~kS3fG42+ z08e210iK}t13ba)2Y5o-5AcMxAK(dVKfn{-et;*U{Qys7`vIP)_5(c8?FV>b+7Ix= zwjbb$Yd^pf-+q86q5S|)V*3G}r1k?m$?XStQrZvjq_!X6NozmAliq%SC!_rUPiFf8 zo~-r*JlX9Bcyihg@Z`21;K^%0z?0v8fTy7S08e520iL4v13bm;2Y5=_5Ac+>AK)o# zKfqJoet@T<{Qys8`vIP+_5(cC?FV>j+7Iy5wjbcBYd^qK-+q9nq5S|)WBUP~ruG9o z&Fu$xTG|irw6-7MX=^{g)82l7r=$G<PiOl9p04%-Jl*XFczW6o@btDH;OT2Wz|-G; zfM-Jc0iKEN2Y4p6AK;nXet>66`vIP*?FV?KwIAS_-hO~*M*9Jtne7L7X0;#SncaSX zXHNS8p1JJ@c;>Yq;F;fkfM-Gb0iK2J2Y42>AK+Qset>65`vIP%?FV?4wIASF-hO~* zMf(AsmF)+3R<$4CS>1ksXHEM7p0(`<c-FNa;91{(fM-Md0iKQR2Y5EMAK=;Cet>67 z`vIP<?FV?awIATw-hO~*NBaSuo$UvBc5Oevv%6Ih)NtR(dzFQkkx7IZ-UFB}`)s!~ z69dC^t7p3v6x|D&f;H67HKZ^wFid4&U=U_NfF+HW*%_uEd%jx%WYY9`&vr{QXK*k~ z-}7uYq`&a;`EF_689EFMDTyVC`Xz}KnbQNG?-l_YS~~e>hhixMLvs!{gAl?@h9!-$ zx(w49pYN7e^eAWw{v)-(u2q<UVRAPEgE)$^b&D9LJ3im7q6qJFa4;~;PhvnbEn^$Q z^tR`_<(XXznx@Bsyr#8(@_8Q7>HW`k^MKub>*;Q3!A%c#LtL4WUtEw`l9)4H@WpNk Mwn<NSGcYg!0GRypGynhq delta 3450 zcmex%{yppZ3-4KZ1H748L>L$tI2byMYq#%zu=^?tH%uh4xOOA&1$O2uxsA={``gX; zGq#)WXKFX!&)jalpQYV=KWn@Beztb={p{`L`#IXp_j9(J@8@ba-_PA{zMrSvd_Qly z`F_53^Zoqo=KBTO&G!qoo9`EDH{UPZZoXfn-F&}jyZL^xcJuw>?dJO>+RgV%wwv#l zYB%36-EO{LrrmtMY`gh>xpwpY^6lpP723`BE4G{OS86xkuiS3FU!~oAziPYrezkV< z{p#)J`!(9l_iMJB@7HQK->=<nzF()^e7|nH`F_22^Zokm=KBrW&G#F&o9{PjH{WmE zZoc27-F&}kyZL^zcJuw_?dJO}+RgV{wwv#_YB%3+-EO|$rrmtMZM*q?yLR*a_U-2T z9oo(JJGPtecWO7^@7!*_-=*DrziYesez$h>{qF7N`#svt_j|UR@Aqmq-|yXSzTc<a ze7|qI`F_84^Zowq=KBNM&G!eko9_>5H{T!JZoWUH-F$y&yZQdGcJuw=?dJO<+RgVz zwwv#dYB%2>-EO`=rrmshY`gjXxOVgX@$KgO6WY!9C$^jKPii;cpWJS~Kc(G#e`>q= z{<L=U{ps!I`!m|j_h+`7@6T#C-=E!XzCWkke1C4c`To3i^Zohl=KBlU&G#3!o9{1b zH{V~}Zoa>y-F$y(yZQdIcJuw^?dJO{+RgV@wwv#-YB%3s-EO|WrrmshZM*sYx_0yZ z_3h^S8`{nHH@2JaZ)!K+-`sA#zop%Le`~w>{<e1W{q61M`#ajr_jk6N@9%0i-{0MC zzQ3p4e1C7d`To9k^Zotp=KCkKoA004ZoYp~yZQdf?dJQZw43jr+HSsoTD$rF>FwtG zXSAE|pV@A{e^$Hs{@Ly3`{%Tq@1NUlzJFf3`TqIs=KB}4o9|!PZoYp}yZQdb?dJQJ zw43i=+HSsoS-biE<?ZJCSG1e&U)gTHe^tBr{?+Z~``5Ia?_b+)zJFc2`Tq6o=KD9a zoA2M)ZoYq0yZQdj?dJQpw43kW+HSsoTf6!G?d|6KceI=D-`Q@yf7f>N{kvNgnLCSX zH<xngsxX@u*KW?z^($qzEUw+Wd(n1b&MUc%!G)W*C!L;Ly-mLUY~LLXfql<4FT{K^ zihBD#H9l3eH)HlYw`6Jk2lFIdE5GZ0kl$u_W6MH!FTR^8FO9wiio}O9e%r3kVf%E= zku5_0{}<$c*ngp6=@W~|Q$n3*3n?o--tunN{Bv7Q39g%<_~%Y{SCeO=&FWuISp1`= z*=h-@ND8rhe$#cXM|y7e!t(mR8lw7=)1Dl%I9tHz?|pAVr_0%v1<RZE+$p#y(E2R? zak{_>4^_*VCEROSxGnQnOH99Bxj5nMrjNY`F1ua6{_y^N_mj)HrXNY>R&BZ=c~X00 z<c+*$W4)H^9AzG#7JqXLQrfNfcD|VD_mD|GW~WbD8t$1frzdR5k?ycCktds8DAXs1 zx+NGD2I{EooAiNOr$y+W>Hm#6ESaZLvlTOA1yr&=2y85nuG)EOO@`lR`K&YX;_tq= zot_jnebz=+lMNPO>t2=WJiRF4)^YBL`Rwp-ZKv3#Ns3mvsrS1T8;4Eiij6$n%6?B{ zMv>?Po98Y+Hk1ZwHeL-6`<WJUdci79{}%Q7qsLm#M7(Rg-nq18TCC*RJt5W+XFALn z1ieFqluRT1{$*>Zb7$?4e=Gj^J)gCXpC?!So3j@?H-@(aYQN{&AF?spb;)tD{BWkj zPpi)C3NAUy#LF*#_G`MIm`Lo#%G2yNua?>!UFk62@$u{Ve~%o#J<HDVT!^@ZblA-+ z#(h#Z9@ht{D{eF|(~vx}u|=~tyY2q7<R7*7J{;fk@O<Ur_|I+o&3{-qu=yRjdenPY z)Tb*?#1%HJ*rxDaU?b~Z-C#YBxL2;vSQ;79GMv3OD;m}YhumFim^Nuw>VnxXs=EF% ztrZFq-YAi^`ooJa|2HJwxcEXKNSLki%`=CaU9Y)vQ-ZEete@}t-j%ES_u=Nv3Rhj@ zMEz8G0&QnX7VDc$m~rN{+v>OR{i>@)W>u;(ZkqPYd{X!|`6r@A1>V1B279kQ5U}Ou zE$QB)=d&9l^Pa5a^(jzZ%(W_1Df!WpHwT%8T7*N>qm~)iwZ75eFj>WZok?hau)8M1 zF_s90#bxIf)a<O}Vyvi_5mL6feP{F69FJF#K^BXAef$(SuSeV21<aMvikQX`?Ra37 zfZLu#_WL2HG&99_FK#+C$9Yq$-@b;$PYWh;glaS&cipPker7?6LGf{phBd3DPeeXq z3#xn>@%VGm{fZM?uP~o(v0J@4W23avg1RkNSyqc4`Ie-(#jDTqO8eAP>Gj8o`VFt# z*xbey_165T?CZjG#wkoe0cKN~A`(8$WQkI^x%@zxL#Xv3pO9Ot1l~q(xUx-0=;7_q zyYJk-Cr{b&P&DY&%X^V(y%*$SwF`azmWxdg>yFc3sT1`0XN{>tO-A{(;*X4NBBiw{ z9l5*P@;`0&J7aL;?c_X>@|LSo2Ae+Ce`8sx8&P*mwLB*4q)C;>CDm8UjLMq4FTJ^1 zT~K@eTGKVB+E%f%<(F>>Hb!b!7x=7;G;G}?G$q=yG-a`;Z{X3XB3q3u-s`9=*RqQ= zoXHiF;5d^fE+J4-bpF+i>wT_-B>lIQz2co${K`!A(&@{^=NGC;+Fs`A_3t^ZDZB0b z)B49B6I<mK-Q2~hx+GX<TwVNIVc)w8lU?}sUrt~Atvav&^t{_EKKFE-sVlYqH814t zbG4?4C2wwPcx|>Wxu15szqT?ZU2gF)r=rr-*{Av|RvgN1ve=yYbPe}Pj&in#eoR}} z?_6NMN-y8@iqNG=^LM@5U!wKp?9EdP^fDNh>&kLP*E8mF*z(nGjd)!>%OSjTht<i* z_nh6=6I(5IubN@>)6A#$f_Q@U`FE0Yxn}T3O{i(ym~v<8KJ6ao?Obanxd)Unh?nQT z(`o8aKcsNbOKH9Hf|phEzkYt`q?+(HPU5tqSN@l;3|BR^8ts2Q=?*nN<T~5H?-$$a z5WNY%$_?+;2W!^fUD9kGl(MFT@lV28{e<Ol1+Om85ttr5;d@kiXq4aD%{Qkw_D|q> z)+fg%`C-|%WhQG)-uh2yO3|}0O*iL~>{c+}d*K#eYx27rk`;#^a|rKWly0!eYMbbX za90mKw!V$!J~4~mGj8u*wATEi=foz@xXrA~cJoCfW*KsX7*5qRt!Mqf_d+<G>)XX7 zH`9CiQyQ!MwjBENSatc1gz0U29{BW0o@rRDs`2*ATOZyz*Noz<?|+kKezkq0Yj97V z(eb>%E5_4ZFW-9I;D1*-<#^WYcguLX&a|ZL-`Dp*_m$53NS3IA*yE=r@Lhd4OXYQU z!tJT+en05?_)I@u=8?Ewz4`j1^>4y&2_N)Zy==piN2YU~eP$NjQro_cE%b2U*}B+V zh3y%yE+0FQ^Wu|pZRE><+|OI8RpLSo)IN6AO;%<-J==Y=`=-aMBrjid-F|+4g9%f~ z49~DD49~vozq#$ki>4}egGj%{-rt@7J~G}={d-nym#dA;0#y!`jU37^d)D6g*ry;_ zKg(Tq|Any5efxbT9r}KCTR7AEj!6zXXK4B5#PLph-1d6nF`FF|N-u2i4lbzRi^=+> zRnB$a-cCI>u;foz!|AWT51(Q=ZPk{2V{s(Qzu3=R37>nbCt9cLc*aP+_m^9>cK633 z6~^6n3Umecsc-I|eSLkV#ikFh`1N$M`quots_9anBFZ>VY+u6n>#Ck%S=GPwS+BER zD_+>B>9i;IadUM<O?pO`%bM58#_xQV798Iow9tA<uU}2ygzB25x8{Y~^)byAePt_p zzD~O>yoB#s?U%xUhDHB;tak;!nZL;SY})gECGTVxuZlda5#-YQVAh3)>@UA@+nxTH zmXX7Ci^Eg&e*NP&u{+($YUFl2Q^-88{m$kdzr_96<8NO-k9=9iANKm^ftk1cm0p@n zRm$DEKKIS{e5+*zM=XvTzWe+Bfwjl$y8ka)47&8!?Qp;F(D6iy?7EYc0*{~Z{&*k$ zb+J{;)W@Qirmp`~e&5XK+5X)=d0R^6AD!8k{&QmD@n+}KKW4^mE=>QlQ1Nm5zwEVc zlRqjZEiSd>`n>Po9?R|9`1BRd9LxNFB<OqT9Non?lEgnV%Y8pMxBmUviom-sl72-t zo#cPq{ayIT<K;}xxSq@V-dJ5|w`ZlEywPsMeJkwm|7S<^4Qz^QCqH1H#GF`MJ9(+j zE?!0^5oUO=VfvcqyQQb^dAgf>`t=99`KCua+sy+~Iep>N-O@}CSf=lLx?2I%iI|@K ze77{yW}fLC&vz>-b{5wLN39WN`K`>r;C_dJL7V{rmNZ)H?Vj%NY`2Pzd2wxU>Exds zilqz;%{kl*LI_n1OB%C-b~B_TmL%$z6eVWnO%HgsTSU>axHkBY)B?L!VFrfD-RK7I zIlg;3%k$kTh*sQm!{@t2naghPp6>a4w>+~=aqaYnPj^c*7e3iN`2)KM*gu|$#kIj9 zWy<@eF)%QkW?*0tL~#b&Q<yUnD>CC#lJv5Q^Ya3{S=m6UI2kw@{FoRRPCwnvz`y_i Dx6xPC diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py new file mode 100644 index 00000000..0dbf50c8 --- /dev/null +++ b/data/unified_datasets/multiwoz21/database.py @@ -0,0 +1,107 @@ +import json +import os +import random +from fuzzywuzzy import fuzz +from itertools import chain +from zipfile import ZipFile +from copy import deepcopy + + +class Database: + def __init__(self): + """extract data.zip and load the database.""" + archive = ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip')) + domains = ['restaurant', 'hotel', 'attraction', 'train', 'hospital', 'police'] + self.dbs = {} + for domain in domains: + with archive.open('data/{}_db.json'.format(domain)) as f: + self.dbs[domain] = json.loads(f.read()) + # add some missing information + self.dbs['taxi'] = { + "taxi_colors": ["black","white","red","yellow","blue","grey"], + "taxi_types": ["toyota","skoda","bmw","honda","ford","audi","lexus","volvo","volkswagen","tesla"], + "taxi_phone": ["^[0-9]{10}$"] + } + self.dbs['police'][0]['postcode'] = "cb11jg" + for entity in self.dbs['hospital']: + entity['postcode'] = "cb20qq" + entity['address'] = "Hills Rd, Cambridge" + + self.dbattr2slot = { + 'openhours': 'open hours', + 'pricerange': 'price range', + 'arriveBy': 'arrive by', + 'leaveAt': 'leave at' + } + + def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60): + """return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.""" + # query the db + if domain == 'taxi': + return [{'taxi_colors': random.choice(self.dbs[domain]['taxi_colors']), + 'taxi_types': random.choice(self.dbs[domain]['taxi_types']), + 'taxi_phone': ''.join([str(random.randint(1, 9)) for _ in range(11)])}] + if domain == 'police': + return deepcopy(self.dbs['police']) + if domain == 'hospital': + department = None + for key, val in state: + if key == 'department': + department = val + if not department: + return deepcopy(self.dbs['hospital']) + else: + return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == department.strip().lower()] + state = list(map(lambda ele: ele if not(ele[0] == 'area' and ele[1] == 'center') else ('area', 'centre'), state)) + + found = [] + for i, record in enumerate(self.dbs[domain]): + constraints_iterator = zip(state, [False] * len(state)) + soft_contraints_iterator = zip(soft_contraints, [True] * len(soft_contraints)) + for (key, val), fuzzy_match in chain(constraints_iterator, soft_contraints_iterator): + if val in ["", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"]: + pass + else: + try: + record_keys = [self.dbattr2slot.get(k, k) for k in record] + if key.lower() not in record_keys: + continue + if key == 'leave at': + val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1]) + val2 = int(record['leaveAt'].split(':')[0]) * 100 + int(record['leaveAt'].split(':')[1]) + if val1 > val2: + break + elif key == 'arrive by': + val1 = int(val.split(':')[0]) * 100 + int(val.split(':')[1]) + val2 = int(record['arriveBy'].split(':')[0]) * 100 + int(record['arriveBy'].split(':')[1]) + if val1 < val2: + break + # elif ignore_open and key in ['destination', 'departure', 'name']: + elif ignore_open and key in ['destination', 'departure']: + continue + elif record[key].strip() == '?': + # '?' matches any constraint + continue + else: + if not fuzzy_match: + if val.strip().lower() != record[key].strip().lower(): + break + else: + if fuzz.partial_ratio(val.strip().lower(), record[key].strip().lower()) < fuzzy_match_ratio: + break + except: + continue + else: + res = deepcopy(record) + res['Ref'] = '{0:08d}'.format(i) + found.append(res) + if len(found) == topk: + return found + return found + + +if __name__ == '__main__': + db = Database() + res = db.query("train", [['departure', 'cambridge'], ['destination','peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']], topk=3) + print(res, len(res)) + # print(db.query("hotel", [['price range', 'moderate'], ['stars','4'], ['type', 'guesthouse'], ['internet', 'yes'], ['parking', 'no'], ['area', 'east']])) diff --git a/data/unified_datasets/multiwoz21/preprocess.py b/data/unified_datasets/multiwoz21/preprocess.py index 52c6860a..25308019 100644 --- a/data/unified_datasets/multiwoz21/preprocess.py +++ b/data/unified_datasets/multiwoz21/preprocess.py @@ -1,7 +1,7 @@ import copy import re from zipfile import ZipFile, ZIP_DEFLATED -from shutil import copy2 +from shutil import copy2, rmtree import json import os from tqdm import tqdm @@ -684,7 +684,6 @@ def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer): }) # correct some value and try to give char level span match = False - ori_value = value value = value.lower() if span and span[0] <= span[1]: # use original span annotation, but tokenizations are different @@ -813,7 +812,7 @@ def preprocess(): } for turn_id, turn in enumerate(ori_dialog['log']): - # correct some grammar error in text, mainly follow tokenization.md in MultiWOZ_2.1 + # correct some grammar errors in the text, mainly following `tokenization.md` in MultiWOZ_2.1 text = turn['text'] text = re.sub(" Im ", " I'm ", text) text = re.sub(" im ", " i'm ", text) @@ -877,13 +876,15 @@ def preprocess(): dialogues = [] for split in splits: dialogues += dialogues_by_split[split] - init_ontology['binary_dialogue_acts'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in init_ontology['binary_dialogue_acts']] + init_ontology['binary_dialogue_acts'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in sorted(init_ontology['binary_dialogue_acts'])] json.dump(dialogues[:10], open(f'dummy_data.json', 'w'), indent=2) json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w'), indent=2) json.dump(init_ontology, open(f'{new_data_dir}/ontology.json', 'w'), indent=2) with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf: for filename in os.listdir(new_data_dir): zf.write(f'{new_data_dir}/{filename}') + rmtree(original_data_dir) + rmtree(new_data_dir) return dialogues, init_ontology if __name__ == '__main__': -- GitLab