From 9a486a52a027adafefa27b1d7ef6b62df3d873c8 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Thu, 2 Dec 2021 14:06:31 +0000
Subject: [PATCH] update multiwoz21 preprocess and add dbquery

---
 data/unified_datasets/multiwoz21/data.zip     | Bin 12314671 -> 12314614 bytes
 data/unified_datasets/multiwoz21/database.py  | 107 ++++++++++++++++++
 .../unified_datasets/multiwoz21/preprocess.py |   9 +-
 3 files changed, 112 insertions(+), 4 deletions(-)
 create mode 100644 data/unified_datasets/multiwoz21/database.py

diff --git a/data/unified_datasets/multiwoz21/data.zip b/data/unified_datasets/multiwoz21/data.zip
index 51ac6b5f326b2c9c6d078c433c89b9e0010301cd..e320219217e24c6cdef07de99ab0cf6af88c7c59 100644
GIT binary patch
delta 3371
zcmZ2~;XUiO^Y2-C1H748L>L$tI2aTPnl`^>xxmisUeGkX?#b>+%pL_z+dcGl>#BfQ
z(}kbz-o*`5>|W5cy&`D0Unxjx`^n?Gw+nN+7BmG*H2-UyG`;ioZu$Dtc~3M%?tTwn
zkdr3d`)udTc*(6wr<Sb>$a%~dGj+w)|90Yk4w<ptW>pFLt~A$V-!qv9Me;U(-f}zn
zTWTAwJ#yLph4deBt_@RU(kE@o(g;p+Nttsk_T%N5VU~wPjauec2L>8w_LPUto%6u+
z<R&-OCKczVgn6+SXBI9lGQ6@bzJBQ=)x(owB$i9+J;*&NZ8UMkOo1;U6Q)b+NiDkZ
zC)<Xtu_t8GWY4VuUK=KS-qclCIai$Pnhy7SmV0Jh`}peX<yGfv#R&PwYX}}W=MmSr
zw*75w!Lim4WleJ@)vT|bGUY^`*Y9@MUfZcElg}0v`0?nosV^4Q{J3a}DC<rgtrHRT
zQw1e_Vnn(d*Q?kxb_N`Z|J<Ic>=CTz&F%TRNg=Sgt)unUh0~#FuEn$LR~XjEuDCpR
zn(A6T(`4?f3y&jKt_t5c#Vt8Q(MtG!M%+!=sk}bMT3<T+;(C{*tx*%*wW*Dpe{Z0{
zQ=J2q=Q=+WzFra}w0ixTPrWOqIfe$$<M2P#ThBRj;~nqy!rq>`yNza5`o6u8+_s&u
zD`L?kft4qVKE*p*M$TwEezVT7zWh#Nkyp@<+WbDhp1n#vQI*qwFFmy@)TQdzqdL_K
zG3#!x(RK~jx$vPVBf9=sO5!4IGyTf;@V(3YZ-yAmZ@+bY|L@M(+f8elW~|gJFgvqp
zWm-$tjpMELeojg2-v=0VBy&1TZI`Z_z5U_$%7gNso9v&n)}Q76cW=RK2gw|vb$W4&
zSF8H_)-gJZZPh;^U6JFMe|iFU)ATAq%SPn}!CL~CayWe6p9(2Y?2%lOm-#}i*!F2%
zV-V-U&IpsOtUR*+*t6T__LeD{CJG$SE1R-*$GHn`-mA1e)kj}+pLys>-S-bV9;Fk*
zj|Ou#hx=da-8o&_!{=CO^7^;)`&Czq%&t^p+%)Zx^-1m5%$2PicY^AkT?x`MP|4l)
z$*uU~?XAs$c~2JFN<45&E7b_^JkvROb34C>fM3YGOvyt()Kb!&7DUz-=qY@ZE?vTO
zVWtCj%9T@+OZ3hCl^E*RJ14ZHcUNzl#p`M7)#bjW@gyh9!>W_V-B(R^W<4!bsLk<3
zYlq<a+Xu}>i@a|Ae79nW!E=F}g_F-e%&?46Nh-b4bUbydV%r&oRD<N>94%`$OZRL(
z!rSuclHu{^FZX?%l)XycNBGx@b4^L_JD7hi)>V%<tFTR-+kNws8NPovgw+2$6Y=gP
z&(&{h7B+?b-u<ZY=MG~=FUOYwCR3Yk9DHERxwZM29qV?6$hRKGfmy3#w&WiQy1m8Y
zNPc+P?%8+yCq9xpKOu<k`#F=$CEqUY6gqyTK8-PR<BOk@x|B@kpP$ZrWcjxQJMsAs
zCN5d`ooD0kbq{BMSu1R8_2!$Wc$ana)w248C+BWIUb3y==gaAL>Q<UAyBIskFVKAH
zM&Y<gY3uLxzILA`X2SJVHOl0E?lmWiG~RCq?RKUeo${$GIPd9|X}xAyr$S;p1Fh$N
zKIc-z_3`ASlOa#|)TW$#A?u^H_|s;sZ&N#GdA+e;-duGu*XwSZOI20XIm;_5i_>Q<
z{McJ)y|n)Ex1A+6c55F;cU+t)d)Kwc>CCE&dxhh}UwA6C-+!@r;&0pSw^Prry<X|&
zlD2=BXs!Q>Z^eQKIjp|viY$$kt;*T>_PCwRCgZ-qT*aMMn|z;E+k_o2Zn9XMxiwNe
zNbp_Bp}8zs>vt}&4sGAAlEr)Jl3nT5`j-k{PTo9a5F5bYRUd7|5Y2d3aF^__jE%>t
zX9a{e?znX_ay@7F^~5%d)vIQ_`jk6IY9arDvRQYdOSBSfG8{j#9-mYg{$rg5-&L(m
zUUH7N7`N@%y>oMcksN<Bi+1N8zP@FvYs1d(?w;DPvpP|SU32e+>%t+sSFqH5IVZk4
zucg~8Va^ray}>&@UOi8$xB0)!X_qH!c~Hul62Y4Ev-%hO_dW<&ZmFoZRUv<C->NNh
zW~FadY?e}#I3sJyW3<3G*Dv$ir7d=jT$6UbxH9LihKo+qoi8rg+9G`$lhZdy%el^Y
zA!_TazB)IWale*Rk>L%A>Z3=l{+d@{t$C|Pc>a_?gWl^oO)ppRpD{3%WT{^@qif16
zrZ;>KI*nWZT}*N_yXQaU;FmcWtv`Qz`xPG4<E=b6Q$~44qkHy(UptMXwp_a>8F>53
zjg{=lo!j(woygh7`*`jRg(_nirbVIuPbkb4-M(N=Y=#fl72PjOnl|72$HgcW^z~Yx
zm}ijAjH9NrLnH*hB(i1}-c(n*_UDl=NBxSty@u~J*3Lb^#b;vKF2<R(N$D?#n|Q06
z8S`H4zWiPMO0x4-^X(~EmTvEH=$lPK&9`RR?RI=?4orKf5Y(f{ovwcDgFP>g3UgrN
zu^#PxzC}r`n}ysH<Yw$~f6(V2uO9V)^N#)k7r*phTbIn`y7B(e66O%~Et418&bTJy
za&>k6&XX5z$*wOc_-hts<WgI$tZ^&LBr7jP#pTLvq0A4TQaq|AYiV4mnkIj=XYFB0
z>o1)G%05Nx5mPkxJvyq)J6BeF$LCH#&bO@G@AH&cVs>v#(%2)bxpdp?@J~SxR_fjn
zv3Q#v8KltK&vs3WXK}<b(+wQI_Hl{2K8(!_o-(h5sh-XJi=IjCMI*_misJW9-3zAs
z?Rc_F@~qNshvwO-zs(L-*lRG}sCcZ-d80yGoAFLTqy}SHi)1=G*L%<Ku#n`H3uSeq
z`1Uu6+boVd+-b7^+Jz+F4WfJgor~t$;BomT-?g|HuSSu3t8WyxS^NJ!xNCEY*qQ^k
zRbD$h{N_=XY_p+WXM^<XicS0YL@TmWWHs(wi>aP;nMttd^HTdGl5A}M{;rZ+wIr&~
z-6MS8sr$9JPpnvP?UOg9Wd6~a+me4yOg!H0eDcT4*v$o>7Aih={J%A9@rfS-8v}Rc
zv{uH|$IXfMZs+FMTz1QT<CR+3bZ_l#U%G!X_8)(7ZvXeRm=N<-GwSrFC45TQQ(>QI
zQ>T@{XYoJdP~4Hq&&wk0(!QPhv8Zm}e||*M$F-oT`2bJ*0UpNo13XOa2Y8s<5Ad+G
zAK+naKfuG*et?I){QwU~`vD%#_5(az?FV?c+Yj)7)bq9<;NfdOz{B5ufJdPH0FPk%
z0Un|D13bd*2Y5u<5AcY#AK(#dKfoj2et<`!{Q!?-`vD%Q_5(c9?FV>d+7Ix^wjbb;
zYd^pv-+q8cq5S}lV*3FerS=0n%IybuRN4>lsJ0*AQENZIquzdiN2C1!k7oM;9<BBR
zJlgFCcy!ti@aVQ5;L&S8z@y)OfXATy0FPn&0Uo3F13bp<2Y5`{5Ac|_AK)=-Kfq(&
zet^fK{Q!?;`vD%S_5(cD?FV>l+7Iy9wjbcJYd^qa-+q9{q5S}lWBUOfr}hIp&g}<y
zT-p!txV9hQace)o<KBLN$D{oKk7xS<9<TNTJl^dGczoIq@c6bL;PGoez~kS3fG42+
z08e210iK}t13ba)2Y5o-5AcMxAK(dVKfn{-et;*U{Qys7`vIP)_5(c8?FV>b+7Ix=
zwjbb$Yd^pf-+q86q5S|)V*3G}r1k?m$?XStQrZvjq_!X6NozmAliq%SC!_rUPiFf8
zo~-r*JlX9Bcyihg@Z`21;K^%0z?0v8fTy7S08e520iL4v13bm;2Y5=_5Ac+>AK)o#
zKfqJoet@T<{Qys8`vIP+_5(cC?FV>j+7Iy5wjbcBYd^qK-+q9nq5S|)WBUP~ruG9o
z&Fu$xTG|irw6-7MX=^{g)82l7r=$G<PiOl9p04%-Jl*XFczW6o@btDH;OT2Wz|-G;
zfM-Jc0iKEN2Y4p6AK;nXet>66`vIP*?FV?KwIAS_-hO~*M*9Jtne7L7X0;#SncaSX
zXHNS8p1JJ@c;>Yq;F;fkfM-Gb0iK2J2Y42>AK+Qset>65`vIP%?FV?4wIASF-hO~*
zMf(AsmF)+3R<$4CS>1ksXHEM7p0(`<c-FNa;91{(fM-Md0iKQR2Y5EMAK=;Cet>67
z`vIP<?FV?awIATw-hO~*NBaSuo$UvBc5Oevv%6Ih)NtR(dzFQkkx7IZ-UFB}`)s!~
z69dC^t7p3v6x|D&f;H67HKZ^wFid4&U=U_NfF+HW*%_uEd%jx%WYY9`&vr{QXK*k~
z-}7uYq`&a;`EF_689EFMDTyVC`Xz}KnbQNG?-l_YS~~e>hhixMLvs!{gAl?@h9!-$
zx(w49pYN7e^eAWw{v)-(u2q<UVRAPEgE)$^b&D9LJ3im7q6qJFa4;~;PhvnbEn^$Q
z^tR`_<(XXznx@Bsyr#8(@_8Q7>HW`k^MKub>*;Q3!A%c#LtL4WUtEw`l9)4H@WpNk
Mwn<NSGcYg!0GRypGynhq

delta 3450
zcmex%{yppZ3-4KZ1H748L>L$tI2byMYq#%zu=^?tH%uh4xOOA&1$O2uxsA={``gX;
zGq#)WXKFX!&)jalpQYV=KWn@Beztb={p{`L`#IXp_j9(J@8@ba-_PA{zMrSvd_Qly
z`F_53^Zoqo=KBTO&G!qoo9`EDH{UPZZoXfn-F&}jyZL^xcJuw>?dJO>+RgV%wwv#l
zYB%36-EO{LrrmtMY`gh>xpwpY^6lpP723`BE4G{OS86xkuiS3FU!~oAziPYrezkV<
z{p#)J`!(9l_iMJB@7HQK->=<nzF()^e7|nH`F_22^Zokm=KBrW&G#F&o9{PjH{WmE
zZoc27-F&}kyZL^zcJuw_?dJO}+RgV{wwv#_YB%3+-EO|$rrmtMZM*q?yLR*a_U-2T
z9oo(JJGPtecWO7^@7!*_-=*DrziYesez$h>{qF7N`#svt_j|UR@Aqmq-|yXSzTc<a
ze7|qI`F_84^Zowq=KBNM&G!eko9_>5H{T!JZoWUH-F$y&yZQdGcJuw=?dJO<+RgVz
zwwv#dYB%2>-EO`=rrmshY`gjXxOVgX@$KgO6WY!9C$^jKPii;cpWJS~Kc(G#e`>q=
z{<L=U{ps!I`!m|j_h+`7@6T#C-=E!XzCWkke1C4c`To3i^Zohl=KBlU&G#3!o9{1b
zH{V~}Zoa>y-F$y(yZQdIcJuw^?dJO{+RgV@wwv#-YB%3s-EO|WrrmshZM*sYx_0yZ
z_3h^S8`{nHH@2JaZ)!K+-`sA#zop%Le`~w>{<e1W{q61M`#ajr_jk6N@9%0i-{0MC
zzQ3p4e1C7d`To9k^Zotp=KCkKoA004ZoYp~yZQdf?dJQZw43jr+HSsoTD$rF>FwtG
zXSAE|pV@A{e^$Hs{@Ly3`{%Tq@1NUlzJFf3`TqIs=KB}4o9|!PZoYp}yZQdb?dJQJ
zw43i=+HSsoS-biE<?ZJCSG1e&U)gTHe^tBr{?+Z~``5Ia?_b+)zJFc2`Tq6o=KD9a
zoA2M)ZoYq0yZQdj?dJQpw43kW+HSsoTf6!G?d|6KceI=D-`Q@yf7f>N{kvNgnLCSX
zH<xngsxX@u*KW?z^($qzEUw+Wd(n1b&MUc%!G)W*C!L;Ly-mLUY~LLXfql<4FT{K^
zihBD#H9l3eH)HlYw`6Jk2lFIdE5GZ0kl$u_W6MH!FTR^8FO9wiio}O9e%r3kVf%E=
zku5_0{}<$c*ngp6=@W~|Q$n3*3n?o--tunN{Bv7Q39g%<_~%Y{SCeO=&FWuISp1`=
z*=h-@ND8rhe$#cXM|y7e!t(mR8lw7=)1Dl%I9tHz?|pAVr_0%v1<RZE+$p#y(E2R?
zak{_>4^_*VCEROSxGnQnOH99Bxj5nMrjNY`F1ua6{_y^N_mj)HrXNY>R&BZ=c~X00
z<c+*$W4)H^9AzG#7JqXLQrfNfcD|VD_mD|GW~WbD8t$1frzdR5k?ycCktds8DAXs1
zx+NGD2I{EooAiNOr$y+W>Hm#6ESaZLvlTOA1yr&=2y85nuG)EOO@`lR`K&YX;_tq=
zot_jnebz=+lMNPO>t2=WJiRF4)^YBL`Rwp-ZKv3#Ns3mvsrS1T8;4Eiij6$n%6?B{
zMv>?Po98Y+Hk1ZwHeL-6`<WJUdci79{}%Q7qsLm#M7(Rg-nq18TCC*RJt5W+XFALn
z1ieFqluRT1{$*>Zb7$?4e=Gj^J)gCXpC?!So3j@?H-@(aYQN{&AF?spb;)tD{BWkj
zPpi)C3NAUy#LF*#_G`MIm`Lo#%G2yNua?>!UFk62@$u{Ve~%o#J<HDVT!^@ZblA-+
z#(h#Z9@ht{D{eF|(~vx}u|=~tyY2q7<R7*7J{;fk@O<Ur_|I+o&3{-qu=yRjdenPY
z)Tb*?#1%HJ*rxDaU?b~Z-C#YBxL2;vSQ;79GMv3OD;m}YhumFim^Nuw>VnxXs=EF%
ztrZFq-YAi^`ooJa|2HJwxcEXKNSLki%`=CaU9Y)vQ-ZEete@}t-j%ES_u=Nv3Rhj@
zMEz8G0&QnX7VDc$m~rN{+v>OR{i>@)W>u;(ZkqPYd{X!|`6r@A1>V1B279kQ5U}Ou
zE$QB)=d&9l^Pa5a^(jzZ%(W_1Df!WpHwT%8T7*N>qm~)iwZ75eFj>WZok?hau)8M1
zF_s90#bxIf)a<O}Vyvi_5mL6feP{F69FJF#K^BXAef$(SuSeV21<aMvikQX`?Ra37
zfZLu#_WL2HG&99_FK#+C$9Yq$-@b;$PYWh;glaS&cipPker7?6LGf{phBd3DPeeXq
z3#xn>@%VGm{fZM?uP~o(v0J@4W23avg1RkNSyqc4`Ie-(#jDTqO8eAP>Gj8o`VFt#
z*xbey_165T?CZjG#wkoe0cKN~A`(8$WQkI^x%@zxL#Xv3pO9Ot1l~q(xUx-0=;7_q
zyYJk-Cr{b&P&DY&%X^V(y%*$SwF`azmWxdg>yFc3sT1`0XN{>tO-A{(;*X4NBBiw{
z9l5*P@;`0&J7aL;?c_X>@|LSo2Ae+Ce`8sx8&P*mwLB*4q)C;>CDm8UjLMq4FTJ^1
zT~K@eTGKVB+E%f%<(F>>Hb!b!7x=7;G;G}?G$q=yG-a`;Z{X3XB3q3u-s`9=*RqQ=
zoXHiF;5d^fE+J4-bpF+i>wT_-B>lIQz2co${K`!A(&@{^=NGC;+Fs`A_3t^ZDZB0b
z)B49B6I<mK-Q2~hx+GX<TwVNIVc)w8lU?}sUrt~Atvav&^t{_EKKFE-sVlYqH814t
zbG4?4C2wwPcx|>Wxu15szqT?ZU2gF)r=rr-*{Av|RvgN1ve=yYbPe}Pj&in#eoR}}
z?_6NMN-y8@iqNG=^LM@5U!wKp?9EdP^fDNh>&kLP*E8mF*z(nGjd)!>%OSjTht<i*
z_nh6=6I(5IubN@>)6A#$f_Q@U`FE0Yxn}T3O{i(ym~v<8KJ6ao?Obanxd)Unh?nQT
z(`o8aKcsNbOKH9Hf|phEzkYt`q?+(HPU5tqSN@l;3|BR^8ts2Q=?*nN<T~5H?-$$a
z5WNY%$_?+;2W!^fUD9kGl(MFT@lV28{e<Ol1+Om85ttr5;d@kiXq4aD%{Qkw_D|q>
z)+fg%`C-|%WhQG)-uh2yO3|}0O*iL~>{c+}d*K#eYx27rk`;#^a|rKWly0!eYMbbX
za90mKw!V$!J~4~mGj8u*wATEi=foz@xXrA~cJoCfW*KsX7*5qRt!Mqf_d+<G>)XX7
zH`9CiQyQ!MwjBENSatc1gz0U29{BW0o@rRDs`2*ATOZyz*Noz<?|+kKezkq0Yj97V
z(eb>%E5_4ZFW-9I;D1*-<#^WYcguLX&a|ZL-`Dp*_m$53NS3IA*yE=r@Lhd4OXYQU
z!tJT+en05?_)I@u=8?Ewz4`j1^>4y&2_N)Zy==piN2YU~eP$NjQro_cE%b2U*}B+V
zh3y%yE+0FQ^Wu|pZRE><+|OI8RpLSo)IN6AO;%<-J==Y=`=-aMBrjid-F|+4g9%f~
z49~DD49~vozq#$ki>4}egGj%{-rt@7J~G}={d-nym#dA;0#y!`jU37^d)D6g*ry;_
zKg(Tq|Any5efxbT9r}KCTR7AEj!6zXXK4B5#PLph-1d6nF`FF|N-u2i4lbzRi^=+>
zRnB$a-cCI>u;foz!|AWT51(Q=ZPk{2V{s(Qzu3=R37>nbCt9cLc*aP+_m^9>cK633
z6~^6n3Umecsc-I|eSLkV#ikFh`1N$M`quots_9anBFZ>VY+u6n>#Ck%S=GPwS+BER
zD_+>B>9i;IadUM<O?pO`%bM58#_xQV798Iow9tA<uU}2ygzB25x8{Y~^)byAePt_p
zzD~O>yoB#s?U%xUhDHB;tak;!nZL;SY})gECGTVxuZlda5#-YQVAh3)>@UA@+nxTH
zmXX7Ci^Eg&e*NP&u{+($YUFl2Q^-88{m$kdzr_96<8NO-k9=9iANKm^ftk1cm0p@n
zRm$DEKKIS{e5+*zM=XvTzWe+Bfwjl$y8ka)47&8!?Qp;F(D6iy?7EYc0*{~Z{&*k$
zb+J{;)W@Qirmp`~e&5XK+5X)=d0R^6AD!8k{&QmD@n+}KKW4^mE=>QlQ1Nm5zwEVc
zlRqjZEiSd>`n>Po9?R|9`1BRd9LxNFB<OqT9Non?lEgnV%Y8pMxBmUviom-sl72-t
zo#cPq{ayIT<K;}xxSq@V-dJ5|w`ZlEywPsMeJkwm|7S<^4Qz^QCqH1H#GF`MJ9(+j
zE?!0^5oUO=VfvcqyQQb^dAgf>`t=99`KCua+sy+~Iep>N-O@}CSf=lLx?2I%iI|@K
ze77{yW}fLC&vz>-b{5wLN39WN`K`>r;C_dJL7V{rmNZ)H?Vj%NY`2Pzd2wxU>Exds
zilqz;%{kl*LI_n1OB%C-b~B_TmL%$z6eVWnO%HgsTSU>axHkBY)B?L!VFrfD-RK7I
zIlg;3%k$kTh*sQm!{@t2naghPp6>a4w>+~=aqaYnPj^c*7e3iN`2)KM*gu|$#kIj9
zWy<@eF)%QkW?*0tL~#b&Q<yUnD>CC#lJv5Q^Ya3{S=m6UI2kw@{FoRRPCwnvz`y_i
Dx6xPC

diff --git a/data/unified_datasets/multiwoz21/database.py b/data/unified_datasets/multiwoz21/database.py
new file mode 100644
index 00000000..0dbf50c8
--- /dev/null
+++ b/data/unified_datasets/multiwoz21/database.py
@@ -0,0 +1,107 @@
+import json
+import os
+import random
+from fuzzywuzzy import fuzz
+from itertools import chain
+from zipfile import ZipFile
+from copy import deepcopy
+
+
+class Database:
+    """Query helper for the multiwoz21 domain databases shipped inside data.zip."""
+
+    def __init__(self):
+        """Load ``data/<domain>_db.json`` for each domain from the data.zip beside this file, then fill known gaps."""
+        domains = ['restaurant', 'hotel', 'attraction', 'train', 'hospital', 'police']
+        with ZipFile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data.zip')) as archive:
+            self.dbs = {domain: json.loads(archive.read('data/{}_db.json'.format(domain))) for domain in domains}
+        # add some missing information
+        self.dbs['taxi'] = {
+            "taxi_colors": ["black","white","red","yellow","blue","grey"],
+            "taxi_types":  ["toyota","skoda","bmw","honda","ford","audi","lexus","volvo","volkswagen","tesla"],
+            "taxi_phone": ["^[0-9]{10}$"]
+        }
+        self.dbs['police'][0]['postcode'] = "cb11jg"
+        for entity in self.dbs['hospital']:
+            entity['postcode'] = "cb20qq"
+            entity['address'] = "Hills Rd, Cambridge"
+
+        # database attribute name -> slot name as it appears in dialogue states
+        self.dbattr2slot = {
+            'openhours': 'open hours',
+            'pricerange': 'price range',
+            'arriveBy': 'arrive by',
+            'leaveAt': 'leave at'
+        }
+
+    @staticmethod
+    def _hhmm(clock):
+        """Turn an 'HH:MM' string into the integer HHMM so that times compare numerically."""
+        return int(clock.split(':')[0]) * 100 + int(clock.split(':')[1])
+
+    def query(self, domain, state, topk, ignore_open=False, soft_contraints=(), fuzzy_match_ratio=60):
+        """Return up to ``topk`` copies of the ``domain`` entities that satisfy the dialogue state.
+
+        Args:
+            state: (slot, value) pairs compared after strip/lower; empty, 'not mentioned' or "don't care"-style values are skipped.
+            topk: the search stops as soon as this many entities have been collected.
+            ignore_open: when true, 'destination' and 'departure' constraints are not checked.
+            soft_contraints: (slot, value) pairs checked with ``fuzz.partial_ratio`` against ``fuzzy_match_ratio``.
+
+        Returns:
+            A list of entity dicts. 'taxi' gives one randomly generated entity, 'police' every entity and
+            'hospital' filters on 'department' only; every other domain adds 'Ref', the zero-padded record index.
+        """
+        if domain == 'taxi':  # NOTE(review): phone is 11 digits drawn from 1-9 although the stored pattern is ^[0-9]{10}$ - confirm
+            return [{'taxi_colors': random.choice(self.dbs[domain]['taxi_colors']),
+                     'taxi_types': random.choice(self.dbs[domain]['taxi_types']),
+                     'taxi_phone': ''.join([str(random.randint(1, 9)) for _ in range(11)])}]
+        if domain == 'police':
+            return deepcopy(self.dbs['police'])
+        if domain == 'hospital':
+            department = ([val for key, val in state if key == 'department'] or [None])[-1]  # last entry wins
+            if not department:
+                return deepcopy(self.dbs['hospital'])
+            return [deepcopy(x) for x in self.dbs['hospital'] if x['department'].lower() == department.strip().lower()]
+        # records are keyed by database attribute ('pricerange'), so slot names ('price range') are translated before indexing
+        slot2dbattr = {slot: attr for attr, slot in self.dbattr2slot.items()}
+        state = [('area', 'centre') if ele[0] == 'area' and ele[1] == 'center' else ele for ele in state]
+        found = []
+        for i, record in enumerate(self.dbs[domain]):
+            constraints = chain(((pair, False) for pair in state), ((pair, True) for pair in soft_contraints))
+            for (key, val), fuzzy_match in constraints:
+                if val in ("", "dont care", 'not mentioned', "don't care", "dontcare", "do n't care"):
+                    continue
+                try:
+                    attr = slot2dbattr.get(key.lower(), key)
+                    if attr not in record:
+                        continue
+                    if attr == 'leaveAt':  # the record must not leave before the requested time
+                        if self._hhmm(val) > self._hhmm(record[attr]):
+                            break
+                    elif attr == 'arriveBy':  # the record must not arrive after the requested time
+                        if self._hhmm(val) < self._hhmm(record[attr]):
+                            break
+                    elif ignore_open and key in ('destination', 'departure'):
+                        continue
+                    elif record[attr].strip() == '?':  # '?' matches any constraint
+                        continue
+                    else:
+                        wanted, actual = val.strip().lower(), record[attr].strip().lower()
+                        matched = (fuzz.partial_ratio(wanted, actual) >= fuzzy_match_ratio) if fuzzy_match else (wanted == actual)
+                        if not matched:
+                            break
+                except (AttributeError, IndexError, KeyError, TypeError, ValueError):
+                    continue  # best effort: a constraint that cannot be evaluated is skipped, not treated as a mismatch
+            else:
+                found.append(dict(deepcopy(record), Ref='{0:08d}'.format(i)))
+                if len(found) == topk:
+                    return found
+        return found
+
+
+if __name__ == '__main__':
+    # smoke test: at most three Tuesday trains from Cambridge to Peterborough that arrive by 11:15
+    train_state = [['departure', 'cambridge'], ['destination', 'peterborough'], ['day', 'tuesday'], ['arrive by', '11:15']]
+    matches = Database().query("train", train_state, topk=3)
+    print(matches, len(matches))
diff --git a/data/unified_datasets/multiwoz21/preprocess.py b/data/unified_datasets/multiwoz21/preprocess.py
index 52c6860a..25308019 100644
--- a/data/unified_datasets/multiwoz21/preprocess.py
+++ b/data/unified_datasets/multiwoz21/preprocess.py
@@ -1,7 +1,7 @@
 import copy
 import re
 from zipfile import ZipFile, ZIP_DEFLATED
-from shutil import copy2
+from shutil import copy2, rmtree
 import json
 import os
 from tqdm import tqdm
@@ -684,7 +684,6 @@ def convert_da(da_dict, utt, sent_tokenizer, word_tokenizer):
             })
             # correct some value and try to give char level span
             match = False
-            ori_value = value
             value = value.lower()
             if span and span[0] <= span[1]:
                 # use original span annotation, but tokenizations are different
@@ -813,7 +812,7 @@ def preprocess():
         }
 
         for turn_id, turn in enumerate(ori_dialog['log']):
-            # correct some grammar error in text, mainly follow tokenization.md in MultiWOZ_2.1
+            # correct some grammar errors in the text, mainly following `tokenization.md` in MultiWOZ_2.1
             text = turn['text']
             text = re.sub(" Im ", " I'm ", text)
             text = re.sub(" im ", " i'm ", text)
@@ -877,13 +876,15 @@ def preprocess():
     dialogues = []
     for split in splits:
         dialogues += dialogues_by_split[split]
-    init_ontology['binary_dialogue_acts'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in init_ontology['binary_dialogue_acts']]
+    init_ontology['binary_dialogue_acts'] = [{'intent':bda[0],'domain':bda[1],'slot':bda[2],'value':bda[3]} for bda in sorted(init_ontology['binary_dialogue_acts'])]
     json.dump(dialogues[:10], open(f'dummy_data.json', 'w'), indent=2)
     json.dump(dialogues, open(f'{new_data_dir}/dialogues.json', 'w'), indent=2)
     json.dump(init_ontology, open(f'{new_data_dir}/ontology.json', 'w'), indent=2)
     with ZipFile('data.zip', 'w', ZIP_DEFLATED) as zf:
         for filename in os.listdir(new_data_dir):
             zf.write(f'{new_data_dir}/{filename}')
+    rmtree(original_data_dir)
+    rmtree(new_data_dir)
     return dialogues, init_ontology
 
 if __name__ == '__main__':
-- 
GitLab