diff --git a/DO.example.dst b/DO.example.dst
new file mode 100644
index 0000000000000000000000000000000000000000..db50985905f5ef13ca37fa610f8d329fa7341ac4
--- /dev/null
+++ b/DO.example.dst
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+# Parameters ------------------------------------------------------
+
+# --- MultiWOZ 2.1 dataset (noisy data)
+TRAIN_TASK="multiwoz21"
+TRAIN_DATA_DIR="data/multiwoz/data/MultiWOZ_2.1"
+# --- MultiWOZ 2.4 dataset (cleaned data, uses legacy format)
+TEST_TASK="multiwoz21_legacy"
+TEST_DATA_DIR="data/MultiWOZ2.4/data/MULTIWOZ2.4"
+
+# Project paths etc. ----------------------------------------------
+
+DATASET_CONFIG="trippy/dataset_config/multiwoz21.json"
+OUT_DIR=results
+mkdir -p ${OUT_DIR}
+
+# Main ------------------------------------------------------------
+
+for step in train dev test; do
+    args_add=""
+    if [ "$step" = "train" ]; then
+	TASK=${TRAIN_TASK}
+	DATA_DIR=${TRAIN_DATA_DIR}
+	args_add="--do_train"
+    elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+	TASK=${TEST_TASK}
+	DATA_DIR=${TEST_DATA_DIR}
+	args_add="--do_eval"
+    fi
+
+    python3 ${TOOLS_DIR}/storm_dst.py \
+	    --task_name=${TASK} \
+	    --model_type="roberta" \
+	    --model_name_or_path="roberta-base" \
+	    --dataset_config=${DATASET_CONFIG} \
+	    --data_dir=${DATA_DIR} \
+	    --predict_type=${step} \
+	    --do_lower_case \
+	    --learning_rate=1e-4 \
+	    --num_train_epochs=10 \
+	    --max_seq_length=180 \
+	    --train_batch_size=48 \
+	    --eval_batch_size=48 \
+	    --output_dir=${OUT_DIR} \
+	    --eval_all_checkpoints \
+	    --save_epochs=1 \
+	    --warmup_proportion=0.1 \
+	    --local_files_only \
+	    --optimizer="Adam" \
+	    ${args_add} \
+	    2>&1 | tee ${OUT_DIR}.${x}/${step}.log
+
+    if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+	python3 trippy/metric_dst.py \
+		--dataset_config=${DATASET_CONFIG} \
+		--file_list="${OUT_DIR}/pred_res.${step}*json" \
+		2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
+    fi
+done
diff --git a/DO.example.glue b/DO.example.glue
new file mode 100644
index 0000000000000000000000000000000000000000..25233bd96252f5d46b52978d4fb3b33bc4630042
--- /dev/null
+++ b/DO.example.glue
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Parameters ------------------------------------------------------
+
+# --- CoLA dataset
+#TASK="cola"
+#DATA_DIR="data/glue_data/CoLA"
+# --- MRPC dataset
+#TASK="mrpc"
+#DATA_DIR="data/glue_data/MRPC"
+# --- RTE dataset
+#TASK="rte"
+#DATA_DIR="data/glue_data/RTE"
+
+# Project paths etc. ----------------------------------------------
+
+OUT_DIR=results
+mkdir -p ${OUT_DIR}
+
+# Main ------------------------------------------------------------
+
+python3 storm.py \
+	--task_name=${TASK} \
+	--train_file=${DATA_DIR}/train_corrupt_30.tsv \
+	--validation_file=${DATA_DIR}/valid.tsv \
+	--test_file=${DATA_DIR}/test.tsv \
+	--model_type="roberta" \
+	--model_name_or_path="roberta-base" \
+	--train_batch_size=32 \
+	--eval_batch_size=32 \
+	--learning_rate=5e-2 \
+	--num_train_epochs=10 \
+	--output_dir=${OUT_DIR} \
+	--local_files_only \
+	--evaluate_rescaling \
+	--save_checkpoints \
+	--save_stats \
+	2>&1 | tee ${OUT_DIR}/train.log
diff --git a/DO.example.spam b/DO.example.spam
new file mode 100644
index 0000000000000000000000000000000000000000..e08878cf94d4e9fbe039a5fe60819aa6b4d0ce79
--- /dev/null
+++ b/DO.example.spam
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Parameters ------------------------------------------------------
+
+# --- SMS dataset
+#TASK="sms"
+#DATA_DIR="data/sms"
+# --- Youtube dataset
+TASK="youtube"
+DATA_DIR="data/youtube"
+
+# Project paths etc. ----------------------------------------------
+
+FEATS_DIR="${DATA_DIR}/tfidf_feats"
+OUT_DIR=results
+mkdir -p ${OUT_DIR}
+
+# Main ------------------------------------------------------------
+
+python3 storm.py \
+	--task_name=${TASK} \
+	--train_file=${DATA_DIR}/train_corrupt_30.tsv \
+	--validation_file=${DATA_DIR}/valid.tsv \
+	--test_file=${DATA_DIR}/test.tsv \
+	--model_type="roberta" \
+	--model_name_or_path="roberta-base" \
+	--train_batch_size=32 \
+	--eval_batch_size=32 \
+	--learning_rate=5e-2 \
+	--num_train_epochs=10 \
+	--output_dir=${OUT_DIR} \
+	--local_files_only \
+	--use_tfidf \
+	--tfidf_path=${FEATS_DIR} \
+	--evaluate_rescaling \
+	--save_checkpoints \
+	--save_stats \
+	2>&1 | tee ${OUT_DIR}/train.log
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e38336ac6d02e67cfae1945e10f3a99772ecae2c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2020 Heinrich Heine University Duesseldorf
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
index f43034405afef48a3ff2a7c869d8ef4a3fa37bb4..e7261401cd4a9d4329baa173eec821d0a4c5a8f8 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,85 @@
-# STORM - Public
+## Introduction
 
-This is the public repository for STORM, to be published at AAAI'25 as "Learning from Noisy Labels via Self-Taught On-the-Fly Meta Loss Rescaling".
+STORM (Self-Taught On-the-fly Rescaling via Meta loss) is a flexible loss rescaling method for learning from noisy labels.
 
-The code will be made available upon presentation at the conference early 2025.
\ No newline at end of file
+STORM
+- uses a novel meta learning scheme called meta loss rescaling that eliminates the need for clean validation data by rescaling both the loss in the inner loop and the meta loss in the outer loop, using noisy validation data.
+- is flexible as it dynamically decides how much importance to assign to a sample at each training stage and keeps learning from the model's own signals.
+- is efficient as it uses features based on sample losses and prediction probabilities instead of sample gradients, reducing computational complexity.
+- is robust as it handles class imbalance, different types of noise, and prevents overfitting.
+
+## Recent updates
+
+- 2025.02.24: Initial commit
+
+## How to run
+
+### Preparations
+
+- Clone the TripPy source code:
+
+```
+git clone https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public.git trippy
+```
+
+- Download and prepare the datasets. See data/README.md for instructions.
+
+### Run
+
+Scripts are provided for demonstrating how to use STORM.
+
+`DO.example.*` will train and evaluate a model with settings that were used for experiments in our paper "Learning from Noisy Labels via Self-Taught On-the-Fly Meta Loss Rescaling".
+- `DO.example.spam` applies STORM to TF-IDF encoded Youtube and SMS datasets.
+- `DO.example.glue` applies STORM to CoLA, MRPC and RTE datasets from the GLUE benchmark.
+- `DO.example.dst` applies STORM to dialogue state tracking on the MultiWOZ 2.4 dataset.
+
+Using the command line parameter `--simulate_only` to storm.py or storm_dst.py will recreate the baselines without applying STORM.
+Note that this will deactivate the loss rescaling, but not the training of the loss rescaler. The loss rescaler merely remains unused.
+
+## Datasets
+
+STORM is not limited to a particular set of datasets. In the paper, we evaluated STORM on the following datasets:
+- Youtube (https://doi.org/10.24432/C58885)
+- SMS (https://doi.org/10.24432/C5CC84)
+- CoLA (https://gluebenchmark.com/tasks)
+- MRPC (https://gluebenchmark.com/tasks)
+- RTE (https://gluebenchmark.com/tasks)
+- MultiWOZ 2.1 (https://github.com/budzianowski/multiwoz.git)
+- MultiWOZ 2.4 (https://github.com/smartyfh/MultiWOZ2.4.git)
+
+The ```--task_name``` is
+- 'youtube', for Youtube
+- 'sms', for SMS
+- 'cola', for CoLA
+- 'mrpc', for MRPC
+- 'rte', for RTE
+- 'multiwoz21', for MultiWOZ 2.1
+- 'multiwoz21_legacy', for MultiWOZ 2.4
+
+## Requirements
+
+- torch (tested: 2.0.0)
+- transformers (tested: 4.18.0)
+- tensorboardX (tested: 2.1)
+
+## Citation
+
+This work is published as [Learning from Noisy Labels via Self-Taught On-the-Fly Meta Loss Rescaling](https://arxiv.org/abs/2412.12955)
+
+If you use STORM in your own work, please cite our work as follows:
+
+```
+@inproceedings{heck2024storm,
+    title = "Learning from Noisy Labels via Self-Taught On-the-Fly Meta Loss Rescaling",
+    author = "Heck, Michael and Geishauser, Christian and Lubis, Nurul and van Niekerk, Carel and 
+    	      Feng, Shutong and Lin, Hsien-Chin and Ruppik, Benjamin Matthias and Vukovic, Renato and
+              Ga{\v{s}}i{\'c}, Milica",
+    booktitle = "Proceedings of the AAAI Conference on Artificial Intelligence",
+    month = "Mar.",
+    year = "2025",
+    volume = "39",
+    address = "Philadelphia, Pennsylvania, USA",
+    publisher = "AAAI Press, Washington, DC, USA",
+    organization = "Association for the Advancement of Artificial Intelligence"
+}
+```
diff --git a/agra.py b/agra.py
new file mode 100644
index 0000000000000000000000000000000000000000..0940657512ade7ba9eaf610b53f489ad426426d0
--- /dev/null
+++ b/agra.py
@@ -0,0 +1,158 @@
+# coding=utf-8
+#
+# Copyright 2024 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of AGRA
+# (arXiv:2306.04502)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import autograd_hacks
+import torch
+
+import numpy as np
+
+from torch.utils.data import (DataLoader, WeightedRandomSampler)
+
+from utils_gpu import (to_device)
+
+
+class F1Loss:
+    def __init__(self, num_classes):
+        self.num_classes = num_classes
+
+    def __call__(self, predictions, labels):
+        softmax = torch.nn.Softmax(dim=1)
+        all_preds = softmax(predictions)
+
+        if self.num_classes == 2:
+            preds = all_preds[:, 1]
+
+            tp = torch.sum(preds * labels)
+            fp = torch.sum(preds * (1 - labels))
+            fn = torch.sum((1 - preds) * labels)
+
+            f1_loss = 1 - ( (2 * tp) / (2 * tp + fn + fp + 1e-8) ) # tp + fn = number positive labels
+        elif self.num_classes > 2:
+            f1 = torch.zeros(self.num_classes)
+
+            for label in range(0, self.num_classes):
+                labels_bin = copy.deepcopy(labels)
+                labels_bin = torch.where(labels_bin == label, 1, 0)
+
+                tp = torch.sum(all_preds[:, label] * labels_bin)
+                fp = torch.sum(all_preds[:, label] * (1 - labels_bin))
+                fn = torch.sum((1 - all_preds[:, label]) * labels_bin)
+
+                f1[label] = (2 * tp) / (2 * tp + fn + fp + 1e-8)
+
+            f1_loss = 1 - torch.mean(f1) # define loss as 1 - macro F1
+        else:
+            raise ValueError("Invalid number of classes")
+
+        return f1_loss
+
+
+class AGRA:
+    def __init__(self, comp_loss, num_classes, is_weighted, dataset, model, device, window_size=1):
+        self.stats = {}
+        self.window_size = window_size # window_size=0 -> unlimited window size
+        self.comp_loss = comp_loss
+        self.num_classes = num_classes
+        self.model = model
+        self.device = device
+        self.dataset = dataset
+
+        autograd_hacks.add_hooks(self.model)
+
+        self.agra_weights = torch.ones(len(self.dataset))
+        if is_weighted:
+            self.agra_weights = torch.tensor(self._compute_weights([i['labels'] for i in self.dataset]))
+
+
+    def _compute_weights(self, train_labels):
+        num_samples = len(train_labels)
+        _, counts = np.unique(train_labels, return_counts=True)
+        assert sum(counts) == num_samples
+        weights = np.zeros(num_samples)
+        for label in range(0, self.num_classes):
+            weights[np.array(train_labels) == label] = 1 / counts[label]
+        return weights
+
+
+    def _get_loss(self):
+        if self.comp_loss == 'F1':
+            comp_loss = F1Loss(self.num_classes)
+            loss_type = 'sum'
+        else:
+            comp_loss = torch.nn.CrossEntropyLoss(reduction='mean')
+            loss_type = 'mean'
+        return comp_loss, loss_type
+
+
+    def _get_comp_grads(self):
+        comp_grads = [params[1].grad.reshape(-1).detach().clone().cpu()
+                      for params in self.model.named_parameters()
+                      if hasattr(params[1], 'grad1') and 'bias' not in params[0]]
+        comp_grads = torch.cat(comp_grads).numpy()
+        if 'comp_grads' not in self.stats:
+            self.stats['comp_grads'] = []
+        self.stats['comp_grads'].append(comp_grads)
+        self.stats['comp_grads'] = self.stats['comp_grads'][-1 * self.window_size:]
+        return np.mean(self.stats['comp_grads'], 0)
+
+
+    def build_dataloader(self, data_collator, batch_size):
+        comp_sampler = WeightedRandomSampler(self.agra_weights, len(self.dataset))
+        self.comp_dataloader = DataLoader(
+            self.dataset, collate_fn=data_collator, batch_size=batch_size, sampler=comp_sampler)
+        self.comp_dataloader = iter(self.comp_dataloader)
+
+
+    def agra_step(self, batch, agra_layer_groups=['classifier']):
+        agra_crit, agra_loss_type = self._get_loss()
+        batch_size = batch['input_ids'].size(0)
+
+        # Get comparison gradients
+        comp_batch = next(self.comp_dataloader)
+        autograd_hacks.clear_grad1(self.model)
+        self.model.classifier.weight.retain_grad() # required for _get_comp_grads()
+        comp_batch = to_device(comp_batch, self.device)
+        comp_outputs = self.model(**comp_batch, suppress_dropout_passes=True)
+        comp_loss = agra_crit(comp_outputs[1][0], comp_batch['labels']) # Logits vs. labels
+        comp_loss.backward()
+        autograd_hacks.compute_grad1(self.model, loss_type=agra_loss_type, layer_groups=agra_layer_groups)
+        comp_grads = self._get_comp_grads()
+
+        del self.model.classifier.weight.grad
+        autograd_hacks.clear_grad1(self.model)
+
+        # Get sample gradients
+        outputs = self.model(**batch)
+        labels = batch['labels']
+        sample_loss = agra_crit(outputs[1][0], labels) # Logits vs. labels
+        sample_loss.backward()
+        autograd_hacks.compute_grad1(self.model, loss_type=agra_loss_type, layer_groups=agra_layer_groups)
+        grads = [params[1].grad1.detach().clone().cpu() for params in self.model.named_parameters() if
+                 hasattr(params[1], 'grad1') and 'bias' not in params[0]]
+
+        # Get gradient scores
+        grad_scores = np.zeros(batch_size)
+        for l_itr in range(batch_size):
+            sample_grads = torch.cat([grad[l_itr].reshape(-1) for grad in grads]).numpy()
+            grad_scores[l_itr] = np.sum(sample_grads * comp_grads) / ((np.linalg.norm(sample_grads) * np.linalg.norm(comp_grads)) + 1e-8)
+
+        autograd_hacks.clear_grad1(self.model)
+
+        return torch.tensor(grad_scores)
+    
diff --git a/autograd_hacks.py b/autograd_hacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..686ad651b0b162280823c68e794e8c485a15a92e
--- /dev/null
+++ b/autograd_hacks.py
@@ -0,0 +1,325 @@
+# coding=utf-8
+#
+# Code retrieved from https://github.com/cybertronai/autograd-hacks
+# and modified for our purposes.
+
+
+"""
+Library for extracting interesting quantites from autograd
+
+Not thread-safe because of module-level variables
+
+Notation:
+o: number of output classes (exact Hessian), number of Hessian samples (sampled Hessian)
+n: batch-size
+do: output dimension (output channels for convolution)
+di: input dimension (input channels for convolution)
+Hi: per-example Hessian of matmul, shaped as matrix of [dim, dim], indices have been row-vectorized
+Hi_bias: per-example Hessian of bias
+Oh, Ow: output height, output width (convolution)
+Kh, Kw: kernel height, kernel width (convolution)
+
+Jb: batch output Jacobian of matmul, output sensitivity for example,class pair, [o, n, ....]
+Jb_bias: as above, but for bias
+
+A, activations: inputs into current layer
+B, backprops: backprop values (aka Lop aka Jacobian-vector product) observed at current layer
+"""
+
+import torch
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import List
+
+_supported_layers = ['Linear', 'Conv2d', \
+                     'InnerFunctionalLinear']  # Supported layer class types
+_hooks_disabled: bool = False           # work-around for https://github.com/pytorch/pytorch/issues/25723
+_enforce_fresh_backprop: bool = False   # global switch to catch double backprop errors on Hessian computation
+
+
+def add_hooks(model: nn.Module) -> None:
+    """
+    Adds hooks to model to save activations and backprop values.
+
+    The hooks will
+    1. save activations into param.activations during forward pass
+    2. append backprops to params.backprops_list during backward pass.
+
+    Call "remove_hooks(model)" to disable this.
+
+    Args:
+        model:
+    """
+
+    global _hooks_disabled
+    _hooks_disabled = False
+
+    handles = []
+    for layer in model.modules():
+        if _layer_type(layer) in _supported_layers:
+            handles.append(layer.register_forward_hook(_capture_activations))
+            handles.append(layer.register_full_backward_hook(_capture_backprops))
+
+    model.__dict__.setdefault('autograd_hacks_hooks', []).extend(handles)
+
+
+def remove_hooks(model: nn.Module) -> None:
+    """
+    Remove hooks added by add_hooks(model)
+    """
+
+    assert model == 0, "not working, remove this after fix to https://github.com/pytorch/pytorch/issues/25723"
+
+    if not hasattr(model, 'autograd_hacks_hooks'):
+        print("Warning, asked to remove hooks, but no hooks found")
+    else:
+        for handle in model.autograd_hacks_hooks:
+            handle.remove()
+        del model.autograd_hacks_hooks
+
+
+def disable_hooks() -> None:
+    """
+    Globally disable all hooks installed by this library.
+    """
+
+    global _hooks_disabled
+    _hooks_disabled = True
+
+
+def enable_hooks() -> None:
+    """the opposite of disable_hooks()"""
+
+    global _hooks_disabled
+    _hooks_disabled = False
+
+
+def is_supported(layer: nn.Module) -> bool:
+    """Check if this layer is supported"""
+
+    return _layer_type(layer) in _supported_layers
+
+
+def _layer_type(layer: nn.Module) -> str:
+    return layer.__class__.__name__
+
+
+def _capture_activations(layer: nn.Module, input: List[torch.Tensor], output: torch.Tensor):
+    """Save activations into layer.activations in forward pass"""
+
+    if _hooks_disabled:
+        return
+    assert _layer_type(layer) in _supported_layers, "Hook installed on unsupported layer, this shouldn't happen"
+    setattr(layer, "activations", input[0].detach())
+
+
+def _capture_backprops(layer: nn.Module, _input, output):
+    """Append backprop to layer.backprops_list in backward pass."""
+    global _enforce_fresh_backprop
+
+    if _hooks_disabled:
+        return
+
+    if _enforce_fresh_backprop:
+        assert not hasattr(layer, 'backprops_list'), "Seeing result of previous backprop, use clear_backprops(model) to clear"
+        _enforce_fresh_backprop = False
+
+    if not hasattr(layer, 'backprops_list'):
+        setattr(layer, 'backprops_list', [])
+    layer.backprops_list.append(output[0].detach())
+
+
+def clear_backprops(model: nn.Module) -> None:
+    """Delete layer.backprops_list in every layer."""
+    for layer in model.modules():
+        if hasattr(layer, 'backprops_list'):
+            del layer.backprops_list
+
+
+def clear_grad1(model: nn.Module) -> None:
+    """Delete attr related to grad1 in every layer."""
+    for layer in model.modules():
+        if hasattr(layer, 'backprops_list'):
+            del layer.backprops_list
+        if hasattr(layer, 'activations'):
+            del layer.activations
+        if hasattr(layer, 'weight') and hasattr(layer.weight, 'grad1'):
+            del layer.weight.grad1
+        if hasattr(layer, 'bias') and hasattr(layer.bias, 'grad1'):
+            del layer.bias.grad1
+
+
+def compute_grad1(model: nn.Module, loss_type: str = 'mean', layer_groups = ['classifier']) -> None:
+    """
+    Compute per-example gradients and save them under 'param.grad1'. Must be called after loss.backprop()
+
+    Args:
+        model:
+        loss_type: either "mean" or "sum" depending whether backpropped loss was averaged or summed over batch
+    """
+
+    assert loss_type in ('sum', 'mean')
+    for module in model.named_modules():
+        layer_name = module[0]
+        layer = module[1]
+        layer_type = _layer_type(layer)
+        if layer_type not in _supported_layers:
+            continue
+        skip = False
+        if len(layer_groups) > 0:
+            skip = True
+            for lg in layer_groups:
+                if lg in layer_name:
+                    skip = False
+                    break
+        if skip:
+            continue
+        assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
+        assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"
+        assert len(layer.backprops_list) == 1, "Multiple backprops detected, make sure to call clear_backprops(model)"
+
+        A = layer.activations
+        n = A.shape[0]
+        if loss_type == 'mean':
+            B = layer.backprops_list[0] * n
+        else:  # loss_type == 'sum':
+            B = layer.backprops_list[0]
+
+        if layer_type in ['Linear', 'InnerFunctionalLinear']:
+            if len(B.size()) == 2:
+                setattr(layer.weight, 'grad1', torch.einsum('ni,nj->nij', B, A))
+            elif len(B.size()) == 3:
+                setattr(layer.weight, 'grad1', torch.einsum('nij,nik->njk', B, A))
+            else:
+                pass
+            if layer.bias is not None:
+                if len(B.size()) == 2:
+                    setattr(layer.bias, 'grad1', B)
+                elif len(B.size()) == 3:
+                    setattr(layer.bias, 'grad1', torch.einsum('nij->nj', B))
+                else:
+                    pass
+        elif layer_type == 'Conv2d':
+            A = torch.nn.functional.unfold(A, layer.kernel_size)
+            B = B.reshape(n, -1, A.shape[-1])
+            grad1 = torch.einsum('ijk,ilk->ijl', B, A)
+            shape = [n] + list(layer.weight.shape)
+            setattr(layer.weight, 'grad1', grad1.reshape(shape))
+            if layer.bias is not None:
+                setattr(layer.bias, 'grad1', torch.sum(B, dim=2))
+
+
+def compute_hess(model: nn.Module,) -> None:
+    """Save Hessian under param.hess for each param in the model"""
+
+    for layer in model.modules():
+        layer_type = _layer_type(layer)
+        if layer_type not in _supported_layers:
+            continue
+        assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
+        assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"
+
+        if layer_type in ['Linear', 'InnerFunctionalLinear']:
+            A = layer.activations
+            B = torch.stack(layer.backprops_list)
+
+            n = A.shape[0]
+            o = B.shape[0]
+
+            A = torch.stack([A] * o)
+            Jb = torch.einsum("oni,onj->onij", B, A).reshape(n*o,  -1)
+            H = torch.einsum('ni,nj->ij', Jb, Jb) / n
+
+            setattr(layer.weight, 'hess', H)
+
+            if layer.bias is not None:
+                setattr(layer.bias, 'hess', torch.einsum('oni,onj->ij', B, B)/n)
+
+        elif layer_type == 'Conv2d':
+            Kh, Kw = layer.kernel_size
+            di, do = layer.in_channels, layer.out_channels
+
+            A = layer.activations.detach()
+            A = torch.nn.functional.unfold(A, (Kh, Kw))       # n, di * Kh * Kw, Oh * Ow
+            n = A.shape[0]
+            B = torch.stack([Bt.reshape(n, do, -1) for Bt in layer.backprops_list])  # o, n, do, Oh*Ow
+            o = B.shape[0]
+
+            A = torch.stack([A] * o)                          # o, n, di * Kh * Kw, Oh*Ow
+            Jb = torch.einsum('onij,onkj->onik', B, A)        # o, n, do, di * Kh * Kw
+
+            Hi = torch.einsum('onij,onkl->nijkl', Jb, Jb)     # n, do, di*Kh*Kw, do, di*Kh*Kw
+            Jb_bias = torch.einsum('onij->oni', B)
+            Hi_bias = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias)
+
+            setattr(layer.weight, 'hess', Hi.mean(dim=0))
+            if layer.bias is not None:
+                setattr(layer.bias, 'hess', Hi_bias.mean(dim=0))
+
+
+def backprop_hess(output: torch.Tensor, hess_type: str) -> None:
+    """
+    Call backprop 1 or more times to get values needed for Hessian computation.
+
+    Args:
+        output: prediction of neural network (ie, input of nn.CrossEntropyLoss())
+        hess_type: type of Hessian propagation, "CrossEntropy" results in exact Hessian for CrossEntropy
+
+    Returns:
+
+    """
+
+    assert hess_type in ('LeastSquares', 'CrossEntropy')
+    global _enforce_fresh_backprop
+    n, o = output.shape
+
+    _enforce_fresh_backprop = True
+
+    if hess_type == 'CrossEntropy':
+        batch = F.softmax(output, dim=1)
+
+        mask = torch.eye(o).expand(n, o, o)
+        diag_part = batch.unsqueeze(2).expand(n, o, o) * mask
+        outer_prod_part = torch.einsum('ij,ik->ijk', batch, batch)
+        hess = diag_part - outer_prod_part
+        assert hess.shape == (n, o, o)
+
+        for i in range(n):
+            hess[i, :, :] = symsqrt(hess[i, :, :])
+        hess = hess.transpose(0, 1)
+
+    elif hess_type == 'LeastSquares':
+        hess = []
+        assert len(output.shape) == 2
+        batch_size, output_size = output.shape
+
+        id_mat = torch.eye(output_size)
+        for out_idx in range(output_size):
+            hess.append(torch.stack([id_mat[out_idx]] * batch_size))
+
+    for o in range(o):
+        output.backward(hess[o], retain_graph=True)
+
+
+def symsqrt(a, cond=None, return_rank=False, dtype=torch.float32):
+    """Symmetric square root of a positive semi-definite matrix.
+    See https://github.com/pytorch/pytorch/issues/25481"""
+
+    s, u = torch.symeig(a, eigenvectors=True)
+    cond_dict = {torch.float32: 1e3 * 1.1920929e-07, torch.float64: 1E6 * 2.220446049250313e-16}
+
+    if cond in [None, -1]:
+        cond = cond_dict[dtype]
+
+    above_cutoff = (abs(s) > cond * torch.max(abs(s)))
+
+    psigma_diag = torch.sqrt(s[above_cutoff])
+    u = u[:, above_cutoff]
+
+    B = u @ torch.diag(psigma_diag) @ u.t()
+    if return_rank:
+        return B, len(psigma_diag)
+    else:
+        return B
diff --git a/data/README.md b/data/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d2d7035776034ec072e73a30a3ae35dfd79c3b1e
--- /dev/null
+++ b/data/README.md
@@ -0,0 +1,69 @@
+## Supported datasets
+
+Datasets should go into the ```data/``` folder.
+
+### Youtube:
+
+The original URL (http://dcomp.sor.ufscar.br/talmeida/youtubespamcollection) is not active anymore.
+The original files are archived at (https://doi.org/10.24432/C58885).
+
+For easy reproducibility, we provide our re-formatted files in ```data/youtube```.
+
+### SMS
+
+The original URL (http://www.dt.fee.unicamp.br/~tiago/smsspamcollection) is not active anymore.
+The original files are archived at (https://doi.org/10.24432/C5CC84).
+
+For easy reproducibility, we provide our re-formatted files in ```data/sms```.
+
+### GLUE
+
+The GLUE tasks and links to their files are listed on (https://gluebenchmark.com/tasks).
+
+```
+wget https://dl.fbaipublicfiles.com/glue/data/CoLA.zip
+unzip CoLA.zip -d CoLA
+wget https://dl.fbaipublicfiles.com/glue/data/RTE.zip
+unzip RTE -d RTE
+mkdir MRPC
+python download_mrpc.py
+```
+
+### MultiWOZ 2.1
+
+```
+git clone https://github.com/budzianowski/multiwoz.git
+unzip multiwoz/data/MultiWOZ_2.1.zip -d multiwoz/data/
+python split_multiwoz_data.py --data_dir multiwoz/data/MultiWOZ_2.1
+```
+
+### MultiWOZ 2.4
+
+```
+git clone https://github.com/smartyfh/MultiWOZ2.4.git
+unzip MultiWOZ2.4/data/MULTIWOZ2.4.zip -d MultiWOZ2.4/data/
+python split_multiwoz_data.py --data_dir MultiWOZ2.4/data/MULTIWOZ2.4
+```
+
+## Add symmetric noise
+
+```
+python corrupt_data.py --file_in youtube/train.tsv --file_out youtube/train_corrupt_30.tsv --task youtube --corruption 0.3 --seed 42
+python corrupt_data.py --file_in sms/train.tsv --file_out sms/train_corrupt_30.tsv --task sms --corruption 0.3 --seed 42
+python corrupt_data.py --file_in CoLA/train.tsv --file_out CoLA/train_corrupt_30.tsv --task cola --corruption 0.3 --seed 42
+python corrupt_data.py --file_in MRPC/train.tsv --file_out MRPC/train_corrupt_30.tsv --task mrpc --corruption 0.3 --seed 42
+python corrupt_data.py --file_in RTE/train.tsv --file_out RTE/train_corrupt_30.tsv --task rte --corruption 0.3 --seed 42
+```
+
+## TF-IDF features
+
+The TF-IDF features for the Youtube and SMS datasets used for our experiments were
+computed with the help of the AGRA code base (https://github.com/anasedova/AGRA).
+
+For easy reproducibility, we provide our computed TF-IDF features as pickle files, found in the
+respective data folders.
+
+```
+gunzip youtube/tfidf_feats/*.gz
+gunzip sms/tfidf_feats/*.gz
+```
diff --git a/data/corrupt_data.py b/data/corrupt_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8a9a376f0b4c7e91307570260a3048faa733541
--- /dev/null
+++ b/data/corrupt_data.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+
+import argparse
+import json
+import numpy as np
+from transformers import (set_seed)
+
+
+DATA_SPECS = {'youtube': {'has_header': True, 'label': 2, 'sentence1': 1, 'sentence2': None, 'num_labels': 2, 'label_map': None},
+              'sms': {'has_header': True, 'label': 2, 'sentence1': 1, 'sentence2': None, 'num_labels': 2, 'label_map': None},
+              'cola': {'has_header': False, 'label': 1, 'sentence1': 3, 'sentence2': None, 'num_labels': 2, 'label_map': None},
+              'mrpc': {'has_header': True, 'label': 0, 'sentence1': 3, 'sentence2': 4, 'num_labels': 2, 'label_map': None},
+              'rte': {'has_header': True, 'label': 3, 'sentence1': 1, 'sentence2': 2, 'num_labels': 2, 'label_map': {'not_entailment': 0, 'entailment': 1, 0: 'not_entailment', 1: 'entailment'}},
+}
+
+
+def load_raw_dataset(input_file, data_specs):
+    raw_data = []
+    with open(input_file, "r", encoding='utf-8') as f:
+        header = None
+        for l_itr, line in enumerate(f):
+            if data_specs['has_header'] and l_itr == 0:
+                header = line.strip()
+                continue
+            raw_data_point = line.strip().split('\t')
+            raw_data.append(raw_data_point)
+    return raw_data, header
+
+
+def corrupt_data(dataset, corruption_rate, data_specs):
+    def label_map(label, lbl_map):
+        if lbl_map is not None:
+            return lbl_map[label]
+        else:
+            return label
+
+    is_corrupted = []
+    corruption_prob = np.random.random_sample(len(dataset))
+    for i_itr, i in enumerate(dataset):
+        label = int(label_map(dataset[i_itr][data_specs['label']], data_specs['label_map']))
+        if corruption_prob[i_itr] <= corruption_rate:
+            if data_specs['num_labels'] == 1:
+                raise NotImplementedError()
+            elif data_specs['num_labels'] == 2:
+                rn = int(not label) # Flips a binary value
+            else:
+                rn = np.random.choice([e for e in range(data_specs['num_labels']) if e != label])
+            dataset[i_itr][data_specs['label']] = str(label_map(rn, data_specs['label_map']))
+            is_corrupted.append(True)
+        else:
+            is_corrupted.append(False)
+    return is_corrupted
+
+
+def main():
+    parser = argparse.ArgumentParser(description="...")
+    
+    # Required parameters
+    parser.add_argument("--file_in", type=str, default=None, required=True, help="")
+    parser.add_argument("--file_out", type=str, default=None, required=True, help="")
+    parser.add_argument("--task", type=str, default=None, required=True, help="")
+    parser.add_argument("--corruption", type=float, default=0.0, required=True, help="")
+    parser.add_argument("--seed", type=int, default=42, help="")
+
+    args = parser.parse_args()
+
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    raw_data, header = load_raw_dataset(args.file_in, DATA_SPECS[args.task])
+
+    is_corrupted = corrupt_data(raw_data, args.corruption, data_specs=DATA_SPECS[args.task])
+
+    if header is not None:
+        header += '\tCorrupted'    
+    with open(args.file_out, "w") as f:
+        if header is not None:
+            f.write(header + '\n')
+        for i_itr, i in enumerate(raw_data):
+            f.write('\t'.join(i + [str(is_corrupted[i_itr])]) + '\n')
+
+
+if __name__ == "__main__":
+    main()
diff --git a/data/download_mrpc.py b/data/download_mrpc.py
new file mode 100644
index 0000000000000000000000000000000000000000..88bf3e44d0a60492a904a0eec30a19e45a33aba3
--- /dev/null
+++ b/data/download_mrpc.py
@@ -0,0 +1,59 @@
+# coding=utf-8
+#
+# Adopted from https://github.com/nyu-mll/jiant/blob/master/scripts/download_glue_data.py
+# (not accessible anymore as of 2024)
+
+import argparse
+import os
+import urllib.request
+
+MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
+MRPC_DEV = 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc'
+MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
+
+
+def format_mrpc(data_dir):
+    mrpc_train_file = os.path.join(data_dir, "msr_paraphrase_train.txt")
+    mrpc_test_file = os.path.join(data_dir, "msr_paraphrase_test.txt")
+    urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
+    urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
+    urllib.request.urlretrieve(MRPC_DEV, os.path.join(data_dir, "dev_ids.tsv"))
+
+    dev_ids = []
+    with open(os.path.join(data_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
+        for row in ids_fh:
+            dev_ids.append(row.strip().split('\t'))
+
+    with open(mrpc_train_file, encoding="utf8") as data_fh, \
+         open(os.path.join(data_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
+         open(os.path.join(data_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
+        header = data_fh.readline()
+        train_fh.write(header)
+        dev_fh.write(header)
+        for row in data_fh:
+            label, id1, id2, s1, s2 = row.strip().split('\t')
+            if [id1, id2] in dev_ids:
+                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
+            else:
+                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
+
+    with open(mrpc_test_file, encoding="utf8") as data_fh, \
+            open(os.path.join(data_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
+        header = data_fh.readline()
+        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
+        for idx, row in enumerate(data_fh):
+            label, id1, id2, s1, s2 = row.strip().split('\t')
+            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
+
+
+def main():
+    parser = argparse.ArgumentParser("Download and prepare MRPC.")
+    parser.add_argument('--data_dir', type=str, default="MRPC",
+                        help='directory to save data to')
+    args = parser.parse_args()
+
+    format_mrpc(args.data_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/data/sms/test.tsv b/data/sms/test.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..06bf06527b1dc0b0d0578da6681d0bac9a80c85e
--- /dev/null
+++ b/data/sms/test.tsv
@@ -0,0 +1,501 @@
+index	sentence1	label
+0	eh u send wrongly lar...	0
+1	tmr timin still da same wat cos i got lesson until 6...	0
+2	ok but tell me half an hr b4 u come i need 2 prepare.	0
+3	oh ok no prob..	0
+4	great. i was getting worried about you. just know that a wonderful and caring person like you will have only the best in life. know that u r wonderful and god's love is yours.	0
+5	take some small dose tablet for fever	0
+6	but pls dont play in others life.	0
+7	host-based idps for linux systems.	0
+8	so what about you. what do you remember	0
+9	if we win its really no 1 side for long time.	0
+10	babe: u want me dont u baby! im nasty and have a thing 4 filthyguys. fancy a rude time with a sexy bitch. how about we go slo n hard! txt xxx slo(4msgs)	1
+11	i cant pick the phone right now. pls send a message	0
+12	when/where do i pick you up	0
+13	k:)k:)what are detail you want to transfer?acc no enough?	0
+14	hey, iouri gave me your number, i'm wylie, ryan's friend	0
+15	wat r u doing now?	0
+16	i'm already back home so no probably not	0
+17	anything is valuable in only 2 situations: first- before getting it... second- after loosing it...	0
+18	it has everything to do with the weather. keep extra warm. its a cold but nothing serious. pls lots of vitamin c	0
+19	no problem baby. is this is a good time to talk? i called and left a message.	0
+20	haha... can... but i'm having dinner with my cousin...	0
+21	when can ?_ come out?	0
+22	swhrt how u dey,hope ur ok, tot about u 2day.love n miss.take care.	0
+23	short but cute: \be a good person	0
+24	i'm at home. please call	0
+25	ok. not sure what time tho as not sure if can get to library before class. will try. see you at some point! have good eve.	0
+26	pls pls find out from aunt nike.	0
+27	freemsg: claim ur 250 sms messages-text ok to 84025 now!use web2mobile 2 ur mates etc. join txt250.com for 1.50p/wk. t&c box139, la32wu. 16 . remove txtx or stop	1
+28	sorry that took so long, omw now	0
+29	hey loverboy! i love you !! i had to tell ... i look at your picture and ache to feel you between my legs ... fuck i want you ... i need you ... i crave you .	0
+30	sorry, i'll call later  &lt;#&gt; mins	0
+31	ya tel, wats ur problem..	0
+32	ok lor... sony ericsson salesman... i ask shuhui then she say quite gd 2 use so i considering...	0
+33	we have pizza if u want	0
+34	i know she called me	0
+35	yes..he is really great..bhaji told kallis best cricketer after sachin in world:).very tough to get out.	0
+36	y ?_ wan to go there? c doctor?	0
+37	i reach home safe n sound liao...	0
+38	the gas station is like a block away from my house, you'll drive right by it since armenia ends at swann and you have to take howard	0
+39	i??m cool ta luv but v.tired 2 cause i have been doin loads of planning all wk, we have got our social services inspection at the nursery! take care & spk sn x.	0
+40	i'm not smoking while people use \wylie smokes too much\" to justify ruining my shit"	0
+41	ok but knackered. just came home and went to sleep! not good at this full time work lark.	0
+42	2marrow only. wed at  &lt;#&gt;  to 2 aha.	0
+43	simply sitting and watching match in office..	0
+44	25p 4 alfie moon's children in need song on ur mob. tell ur m8s. txt tone charity to 8007 for nokias or poly charity for polys: zed 08701417012 profit 2 charity.	1
+45	sure, if i get an acknowledgement from you that it's astoundingly tactless and generally faggy to demand a blood oath fo	0
+46	its sarcasm.. .nt scarcasim	0
+47	are u coming to the funeral home	0
+48	night sweet, sleep well! i've just been to see the exorcism of emily rose and may never sleep again! hugs and snogs!	0
+49	k i'll take care of it	0
+50	carlos'll be here in a minute if you still need to buy	0
+51	22 days to kick off! for euro2004 u will be kept up to date with the latest news and results daily. to be removed send get txt stop to 83222	1
+52	fine am simply sitting.	0
+53	hmmm.. thk sure got time to hop ard... ya, can go 4 free abt... muz call u to discuss liao...	0
+54	it's fine, imma get a drink or somethin. want me to come find you?	0
+55	do you want a new nokia 3510i colour phone deliveredtomorrow? with 300 free minutes to any mobile + 100 free texts + free camcorder reply or call 08000930705	1
+56	u 2.	0
+57	am watching house ??? very entertaining ??? am getting the whole hugh laurie thing ??? even with the stick ??? indeed especially with the stick.	0
+58	we tried to contact you re your reply to our offer of a video handset? 750 anytime networks mins? unlimited text? camcorder? reply or call 08000930705 now	1
+59	?? collecting ur laptop then going to configure da settings izzit?	0
+60	just trying to figure out when i'm suppose to see a couple different people this week. we said we'd get together but i didn't set dates	0
+61	your daily text from me ??? a favour this time	0
+62	but am going to college pa. what to do. are else ill come there it self. pa.	0
+63	hey i booked the kb on sat already... what other lessons are we going for ah? keep your sat night free we need to meet and confirm our lodging	0
+64	network operator. the service is free. for t & c's visit 80488.biz	1
+65	i am in your office na.	0
+66	it's that time of the week again, ryan	0
+67	you will be in the place of that man	0
+68	please protect yourself from e-threats. sib never asks for sensitive information like passwords,atm/sms pin thru email. never share your password with anybody.	0
+69	lol your right. what diet? everyday i cheat anyway. i'm meant to be a fatty :(	0
+70	collect your valentine's weekend to paris inc flight & hotel + ??200 prize guaranteed! text: paris to no: 69101. www.rtf.sphosting.com	1
+71	i'm at home n ready...	0
+72	glad it went well :) come over at 11 then we'll have plenty of time before claire goes to work.	0
+73	when u wana see it then	0
+74	nope. i just forgot. will show next week	0
+75	i know you are thinkin malaria. but relax, children cant handle malaria. she would have been worse and its gastroenteritis. if she takes enough to replace her loss her temp will reduce. and if you give her malaria meds now she will just vomit. its a self limiting illness she has which means in a few days it will completely stop	0
+76	urgent! please call 09066612661 from your landline, your complimentary 4* lux costa del sol holiday or ??1000 cash await collection. ppm 150 sae t&cs james 28, eh74rr	1
+77	how much would it cost to hire a hitman	0
+78	yar... i tot u knew dis would happen long ago already.	0
+79	sorry that was my uncle. i.ll keep in touch	0
+80	had your mobile 10 mths? update to latest orange camera/video phones for free. save ??s with free texts/weekend calls. text yes for a callback orno to opt out	1
+81	thanks for understanding. i've been trying to tell sura that.	0
+82	now that you have started dont stop. just pray for more good ideas and anything i see that can help you guys i.ll forward you a link.	0
+83	welcome to uk-mobile-date this msg is free giving you free calling to 08719839835. future mgs billed at 150p daily. to cancel send \go stop\" to 89123"	1
+84	we're all getting worried over here, derek and taylor have already assumed the worst	0
+85	which is why i never wanted to tell you any of this. which is why i'm so short with you and on-edge as of late.	0
+86	some friends want me to drive em someplace, probably take a while	0
+87	\aww you must be nearly dead!well jez iscoming over todo some workand that whilltake forever!\""	0
+88	oh. u must have taken your real valentine out shopping first.	0
+89	every king was once a crying baby and every great building was once a map.. not imprtant where u r today, but where u wil reach tomorw. gud ni8	0
+90	hi. || do u want | to join me with sts later? || meeting them at five. || call u after class.	0
+91	indians r poor but india is not a poor country. says one of the swiss bank directors. he says that \ &lt;#&gt;  lac crore\" of indian money is deposited in swiss banks which can be used for 'taxless' budget for  &lt;#&gt;  yrs. can give  &lt;#&gt;  crore jobs to all indians. from any village to delhi 4 lane roads. forever free power suply to more than  &lt;#&gt;  social projects. every citizen can get monthly  &lt;#&gt; /- for  &lt;#&gt;  yrs. no need of world bank &amp; imf loan. think how our money is blocked by rich politicians. we have full rights against corrupt politicians. itna forward karo ki pura india padhe.g.m.\""	0
+92	just finished eating. got u a plate. not leftovers this time.	0
+93	ello babe u ok?	0
+94	don't worry, * is easy once have ingredients!	0
+95	sir, waiting for your letter.	0
+96	but i'm surprised she still can guess right lor...	0
+97	you should get more chicken broth if you want ramen unless there's some i don't know about	0
+98	same. wana plan a trip sometme then	0
+99	i walked an hour 2 c u! doesn??t that show i care y wont u believe im serious?	0
+100	ur cash-balance is currently 500 pounds - to maximize ur cash-in now send go to 86688 only 150p/meg. cc: 08718720201 hg/suite342/2lands row/w1j6hl	1
+101	good afternoon my boytoy. how goes that walking here and there day ? did you get that police abstract? are you still out and about? i wake and miss you babe	0
+102	your b4u voucher w/c 27/03 is marsms. log onto www.b4utele.com for discount credit. to opt out reply stop. customer care call 08717168528	1
+103	i am waiting for your call sir.	0
+104	oops i was in the shower when u called. hey a parking garage collapsed at university hospital. see i'm not crazy. stuff like that does happen.	0
+105	we are both fine. thanks	0
+106	do u ever get a song stuck in your head for no reason and it won't go away til u listen to it like 5 times?	0
+107	thanks for yesterday sir. you have been wonderful. hope you enjoyed the burial. mojibiola	0
+108	hi dear call me its urgnt. i don't know whats your problem. you don't want to work or if you have any other problem at least tell me. wating for your reply.	0
+109	sounds better than my evening im just doing my costume. im not sure what time i finish tomorrow but i will txt you at the end.	0
+110	hack chat. get backdoor entry into 121 chat rooms at a fraction of the cost. reply neo69 or call 09050280520, to subscribe 25p pm. dps, bcm box 8027 ldn, wc1n3xx	1
+111	only just got this message, not ignoring you. yes, i was. shopping that is	0
+112	k actually can you guys meet me at the sunoco on howard? it should be right on the way	0
+113	thats cool. i liked your photos. you are very sexy!	0
+114	no, its true..k,do u knw dis no. &lt;#&gt; ?	0
+115	i told your number to gautham..	0
+116	somewhere out there beneath the pale moon light someone think in of u some where out there where dreams come true... goodnite &amp; sweet dreams	0
+117	hey next sun 1030 there's a basic yoga course... at bugis... we can go for that... pilates intro next sat.... tell me what time you r free	0
+118	aight we can pick some up, you open before tonight?	0
+119	my slave! i want you to take 2 or 3 pictures of yourself today in bright light on your cell phone! bright light!	0
+120	boy; i love u grl: hogolo boy: gold chain kodstini grl: agalla boy: necklace madstini grl: agalla boy: hogli 1 mutai eerulli kodthini! grl: i love u kano;-)	0
+121	ha ha ha good joke. girls are situation seekers.	0
+122	lovely smell on this bus and it ain't tobacco...	0
+123	someone u know has asked our dating service 2 contact you! cant guess who? call 09058097189 now all will be revealed. pobox 6, ls15hb 150p	1
+124	ma head dey swell oh. thanks for making my day	0
+125	wat uniform? in where get?	0
+126	e admin building there? i might b slightly earlier... i'll call u when i'm reaching...	0
+127	okie	0
+128	our ride equally uneventful - not too many of those pesky cyclists around at that time of night ;).	0
+129	do you want a new nokia 3510i colour phone deliveredtomorrow? with 300 free minutes to any mobile + 100 free texts + free camcorder reply or call 08000930705.	1
+130	urgent! your mobile no. was awarded ??2000 bonus caller prize on 5/9/03 this is our final try to contact u! call from landline 09064019788 box42wr29c, 150ppm	1
+131	eh u remember how 2 spell his name... yes i did. he v naughty make until i v wet.	0
+132	aight, can you text me the address?	0
+133	hope things went well at 'doctors' ;) reminds me i still need 2go.did u c d little thing i left in the lounge?	0
+134	, how's things? just a quick question.	0
+135	today is \song dedicated day..\" which song will u dedicate for me? send this to all ur valuable frnds but first rply me..."	0
+136	please call our customer service representative on freephone 0808 145 4742 between 9am-11pm as you have won a guaranteed ??1000 cash or ??5000 prize!	1
+137	what you need. you have a person to give na.	0
+138	but you were together so you should be thinkin about him	0
+139	don't make life too stressfull.. always find time to laugh.. it may not add years to your life! but surely adds more life to ur years!! gud ni8..swt dreams..	0
+140	at home also.	0
+141	house-maid is the murderer, coz the man was murdered on  &lt;#&gt; th january.. as public holiday all govt.instituitions are closed,including post office..	0
+142	free for 1st week! no1 nokia tone 4 ur mobile every week just txt nokia to 8077 get txting and tell ur mates. www.getzed.co.uk pobox 36504 w45wq 16+ norm150p/tone	1
+143	mark works tomorrow. he gets out at 5. his work is by your house so he can meet u afterwards.	0
+144	ok . . now i am in bus. . if i come soon i will come otherwise tomorrow	0
+145	thanks for picking up the trash.	0
+146	so when do you wanna gym?	0
+147	win a ??1000 cash prize or a prize worth ??5000	1
+148	sometimes we put walls around our hearts,not just to be safe from getting hurt.. but to find out who cares enough to break the walls &amp; get closer.. goodnoon:)	0
+149	y?where u at dogbreath? its just sounding like jan c that??s al!!!!!!!!!	0
+150	shall i ask one thing if you dont mistake me.	0
+151	forgot it takes me 3 years to shower, sorry. where you at/your phone dead yet?	0
+152	hey so whats the plan this sat?	0
+153	says that he's quitting at least5times a day so i wudn't take much notice of that. nah, she didn't mind. are you gonna see him again? do you want to come to taunton tonight? u can tell me all about !	0
+154	why she wants to talk to me	0
+155	sounds great! are you home now?	0
+156	watching cartoon, listening music &amp; at eve had to go temple &amp; church.. what about u?	0
+157	k:)k.are you in college?	0
+158	pls help me tell ashley that i cant find her number oh	0
+159	convey my regards to him	0
+160	thanks for your ringtone order, reference number x49. your mobile will be charged 4.50. should your tone not arrive please call customer services 09065989182. from: [colour=red]text[/colour]txtstar	1
+161	ok leave no need to ask	0
+162	can do lor...	0
+163	hey r ?_ still online? i've finished the formatting...	0
+164	i don't think he has spatula hands!	0
+165	of course. i guess god's just got me on hold right now.	0
+166	oic cos me n my sis got no lunch today my dad went out... so dunno whether 2 eat in sch or wat...	0
+167	when ?_ login dat time... dad fetching ?_ home now?	0
+168	k come to nordstrom when you're done	0
+169	k.k:)i'm going to tirunelvali this week to see my uncle ..i already spend the amount by taking dress .so only i want money.i will give it on feb 1	0
+170	hey happy birthday...	0
+171	we confirm eating at esplanade?	0
+172	goodnight, sleep well da please take care pa. please.	0
+173	sms. ac blind date 4u!: rodds1 is 21/m from aberdeen, united kingdom. check him out http://img. sms. ac/w/icmb3cktz8r7!-4 no blind dates send hide	1
+174	annoying isn't it.	0
+175	i got arrested for possession at, i shit you not,  &lt;time&gt;  pm	0
+176	she's borderline but yeah whatever.	0
+177	i am in a marriage function	0
+178	how u doin baby girl ?? hope u are okay every time i call ure phone is off! i miss u get in touch	0
+179	want the latest video handset? 750 anytime any network mins? half price line rental? reply or call 08000930705 for delivery tomorrow	1
+180	i'm done...	0
+181	i might go 2 sch. yar at e salon now v boring.	0
+182	lemme know when you're here	0
+183	hope you are having a great day.	0
+184	aathi..where are you dear..	0
+185	she said,'' do u mind if i go into the bedroom for a minute ? '' ''ok'', i sed in a sexy mood. she came out 5 minuts latr wid a cake...n my wife,	0
+186	i'm now but have to wait till 2 for the bus to pick me.	0
+187	what you thinked about me. first time you saw me in class.	0
+188	urgent! we are trying to contact you. last weekends draw shows that you have won a ??900 prize guaranteed. call 09061701851. claim code k61. valid 12hours only	1
+189	i'm freezing and craving ice. fml	0
+190	well. im computerless. time to make some oreo truffles	0
+191	alright, see you in a bit	0
+192	den only weekdays got special price... haiz... cant eat liao... cut nails oso muz wait until i finish drivin wat, lunch still muz eat wat...	0
+193	yup. thk of u oso boring wat.	0
+194	no he didn't. spring is coming early yay!	0
+195	just got up. have to be out of the room very soon. ??_. i hadn't put the clocks back til at 8 i shouted at everyone to get up and then realised it was 7. wahay. another hour in bed.	0
+196	why don't you go tell your friend you're not sure you want to live with him because he smokes too much then spend hours begging him to come smoke	0
+197	dont forget you can place as many free requests with 1stchoice.co.uk as you wish. for more information call 08707808226.	1
+198	k..k:)where are you?how did you performed?	0
+199	yun buying... but school got offer 2000 plus only...	0
+200	anything lor is she coming?	0
+201	haha but no money leh... later got to go for tuition... haha and looking for empty slots for driving lessons	0
+202	i donno its in your genes or something	0
+203	k.then any other special?	0
+204	okie ?_ wan meet at bishan? cos me at bishan now. i'm not driving today.	0
+205	ok lor then we go tog lor...	0
+206	a ??400 xmas reward is waiting for you! our computer has randomly picked you from our loyal mobile customers to receive a ??400 reward. just call 09066380611	1
+207	today is sorry day.! if ever i was angry with you, if ever i misbehaved or hurt you? plz plz just slap urself bcoz, its ur fault, i'm basically good	0
+208	i'm coming home 4 dinner.	0
+209	i will spoil you in bed as well :)	0
+210	have you not finished work yet or something?	0
+211	hi. happy new year. i dont mean to intrude but can you pls let me know how much tuition you paid last semester and how much this semester is. thanks	0
+212	dont pack what you can buy at any store.like cereals. if you must pack food, pack gari or something 9ja that you will miss.	0
+213	wow so healthy. old airport rd lor. cant thk of anything else. but i'll b bathing my dog later.	0
+214	not much, just some textin'. how bout you?	0
+215	nimbomsons. yep phone knows that one. obviously, cos thats a real word	0
+216	din i tell u jus now 420	0
+217	dont give a monkeys wot they think and i certainly don't mind. any friend of mine&all that! just don't sleep wiv , that wud be annoyin!	0
+218	you'd like that wouldn't you? jerk!	0
+219	my supervisor find 4 me one lor i thk his students. i havent ask her yet. tell u aft i ask her.	0
+220	goin to workout lor... muz lose e fats...	0
+221	wot u up 2? thout u were gonna call me!! txt bak luv k	0
+222	was gr8 to see that message. so when r u leaving? congrats dear. what school and wat r ur plans.	0
+223	i need... coz i never go before	0
+224	\er	0
+225	i got a call from a landline number. . . i am asked to come to anna nagar . . . i will go in the afternoon	0
+226	you are everywhere dirt, on the floor, the windows, even on my shirt. and sometimes when i open my mouth, you are all that comes flowing out. i dream of my world without you, then half my chores are out too. a time of joy for me, lots of tv shows i.ll see. but i guess like all things you just must exist, like rain, hail and mist, and when my time here is done, you and i become one.	0
+227	goodmorning, today i am late for  &lt;#&gt; min.	0
+228	no de.am seeing in online shop so that i asked.	0
+229	pick you up bout 7.30ish? what time are  and that going?	0
+230	i'm not coming home 4 dinner.	0
+231	i will reach ur home in  &lt;#&gt;  minutes	0
+232	but i'm really really broke oh. no amount is too small even  &lt;#&gt;	0
+233	huh so early.. then ?_ having dinner outside izzit?	0
+234	dont put your phone on silent mode ok	0
+235	beerage?	0
+236	i prefer my free days... tues, wed, fri oso can... ?? ask those workin lor...	0
+237	hi where you. you in home or calicut?	0
+238	also maaaan are you missing out	0
+239	am in gobi arts college	0
+240	whens your radio show?	0
+241	and of course you should make a stink!	0
+242	today is accept day..u accept me as? brother sister lover dear1 best1 clos1 lvblefrnd jstfrnd cutefrnd lifpartnr belovd swtheart bstfrnd no rply means enemy	0
+243	final chance! claim ur ??150 worth of discount vouchers today! text yes to 85023 now! savamob, member offers mobile! t cs savamob pobox84, m263uz. ??3.00 subs 16	1
+244	height of confidence: all the aeronautics professors wer calld &amp; they wer askd 2 sit in an aeroplane. aftr they sat they wer told dat the plane ws made by their students. dey all hurried out of d plane.. bt only 1 didnt move... he said:\if it is made by my students	0
+245	fuck babe, i miss you sooooo much !! i wish you were here to sleep with me ... my bed is so lonely ... i go now, to sleep ... to dream of you, my love ...	0
+246	ok no prob...	0
+247	wiskey brandy rum gin beer vodka scotch shampain wine \kudi\"yarasu dhina vaazhthukkal. .."	0
+248	its not that time of the month nor mid of the time?	0
+249	i'm leaving my house now.	0
+250	prepare to be pleasured :)	0
+251	what today-sunday..sunday is holiday..so no work..	0
+252	yup ok...	0
+253	u call me alter at 11 ok.	0
+254	hey i'm bored... so i'm thinking of u... so wat r u doing?	0
+255	probably earlier than that if the station's where i think it is	0
+256	even my brother is not like to speak with me. they treat me like aids patent.	0
+257	how's ur paper?	0
+258	haha awesome, i've been to 4u a couple times. who all's coming?	0
+259	love isn't a decision, it's a feeling. if we could decide who to love, then, life would be much simpler, but then less magical	0
+260	great. never been better. each day gives even more reasons to thank god	0
+261	i have to take exam with in march 3	0
+262	\are you comingdown later?\""	0
+263	this pen thing is beyond a joke. wont a biro do? don't do a masters as can't do this ever again!	0
+264	you available now? i'm like right around hillsborough &amp;  &lt;#&gt; th	0
+265	i'll get there at 3, unless you guys want me to come some time sooner	0
+266	do you want bold 2 or bb torch	0
+267	text banneduk to 89555 to see! cost 150p textoperator g696ga 18+ xxx	1
+268	private! your 2003 account statement for shows 800 un-redeemed s. i. m. points. call 08719899230 identifier code: 41685 expires 07/11/04	1
+269	i wil be there with in  &lt;#&gt;  minutes. got any space	0
+270	captain is in our room:)	0
+271	dear hero,i am leaving to qatar tonite for an apt opportunity.pls do keep in touch at  &lt;email&gt; ,kerala	0
+272	pete,is this your phone still? its jenny from college and leanne.what are you up to now?:)	0
+273	hi babe u r most likely to be in bed but im so sorry about tonight! i really wanna see u tomorrow so call me at 9. love me xxx	0
+274	that's good, because i need drugs	0
+275	you sure your neighbors didnt pick it up	0
+276	i am back. bit long cos of accident on a30. had to divert via wadebridge.i had a brilliant weekend thanks. speak soon. lots of love	0
+277	dunno cos i was v late n when i reach they inside already... but we ate spageddies lor... it's e gals who r laughing at me lor...	0
+278	i will reach before ten morning	0
+279	no need lar i go engin? cos my sis at arts today...	0
+280	haha... sounds crazy, dunno can tahan anot...	0
+281	what is your record for one night? :)	0
+282	freemsg: fancy a flirt? reply date now & join the uks fastest growing mobile dating service. msgs rcvd just 25p to optout txt stop to 83021. reply date now!	1
+283	\hey sorry i didntgive ya a a bellearlier hunny	0
+284	u in town alone?	0
+285	that sucks. i'll go over so u can do my hair. you'll do it free right?	0
+286	unfortunately i've just found out that we have to pick my sister up from the airport that evening so don't think i'll be going out at all. we should try to go out one of th	0
+287	boo i'm on my way to my moms. she's making tortilla soup. yummmm	0
+288	when did i use soc... i use it only at home... ?? dunno how 2 type it in word ar...	0
+289	sorry, i'll call later	0
+290	of cos can lar i'm not so ba dao ok... 1 pm lor... y u never ask where we go ah... i said u would ask on fri but he said u will ask today...	0
+291	oh is it? send me the address	0
+292	pls dont forget to study	0
+293	no we put party 7 days a week and study lightly, i think we need to draw in some custom checkboxes so they know we're hardcore	0
+294	oh k.k..where did you take test?	0
+295	i pocked you up there before	0
+296	u don't know how stubborn i am. i didn't even want to go to the hospital. i kept telling mark i'm not a weak sucker. hospitals are for weak suckers.	0
+297	free entry in 2 a wkly comp to win fa cup final tkts 21st may 2005. text fa to 87121 to receive entry question(std txt rate)t&c's apply 08452810075over18's	1
+298	got it. seventeen pounds for seven hundred ml ??? hope ok.	0
+299	chk in ur belovd ms dict	0
+300	a boy was late 2 home. his father: \power of frndship\""	0
+301	as usual..iam fine, happy &amp; doing well..:)	0
+302	k..i deleted my contact that why?	0
+303	me fine..absolutly fine	0
+304	unlimited texts. limited minutes.	0
+305	i wish that i was with you. holding you tightly. making you see how important you are. how much you mean to me ... how much i need you ... in my life ...	0
+306	carlos says he'll be at mu in  &lt;#&gt;  minutes	0
+307	sad story of a man - last week was my b'day. my wife did'nt wish me. my parents forgot n so did my kids . i went to work. even my colleagues did not wish.	0
+308	for your chance to win a free bluetooth headset then simply reply back with \adp\""	1
+309	dont make ne plans for nxt wknd coz she wants us to come down then ok	0
+310	i'm really not up to it still tonight babe	0
+311	so u wan 2 come for our dinner tonight a not?	0
+312	would you like to see my xxx pics they are so hot they were nearly banned in the uk!	1
+313	honey boo i'm missing u.	0
+314	i know a few people i can hit up and fuck to the yes	0
+315	better. made up for friday and stuffed myself like a pig yesterday. now i feel bleh. but at least its not writhing pain kind of bleh.	0
+316	i'm not sure if its still available though	0
+317	good. no swimsuit allowed :)	0
+318	no no:)this is kallis home ground.amla home town is durban:)	0
+319	am going to take bath ill place the key in window:-)	0
+320	what happened to our yo date?	0
+321	will you come online today night	0
+322	k give me a sec, breaking a  &lt;#&gt;  at cstore	0
+323	i'm good. have you registered to vote?	0
+324	u have a secret admirer who is looking 2 make contact with u-find out who they r*reveal who thinks ur so special-call on 09065171142-stopsms-08	1
+325	all done? all handed in? celebrations in full swing yet?	0
+326	did u got that persons story	0
+327	7 wonders in my world 7th you 6th ur style 5th ur smile 4th ur personality 3rd ur nature 2nd ur sms and 1st \ur lovely friendship\"... good morning dear"	0
+328	sorry, i'll call later	0
+329	please call 08712402779 immediately as there is an urgent message waiting for you	1
+330	olol i printed out a forum post by a guy with the exact same  prob which was fixed with a gpu replacement. hopefully they dont ignore that.	0
+331	now, whats your house # again ? and do you have any beer there ?	0
+332	i am on the way to ur home	0
+333	but really quite funny lor wat... then u shd haf run shorter distance wat...	0
+334	wishing you a beautiful day. each moment revealing even more things to keep you smiling. do enjoy it.	0
+335	its a valentine game. . . send dis msg to all ur friends. .. if 5 answers r d same then someone really loves u. ques- which colour suits me the best?rply me	0
+336	customer place i will call you.	0
+337	where are you lover ? i need you ...	0
+338	ummmmmaah many many happy returns of d day my dear sweet heart.. happy birthday dear	0
+339	aight do you still want to get money	0
+340	i have many dependents	0
+341	all e best 4 ur exam later.	0
+342	well, i was about to give up cos they all said no they didn???t do one nighters. i persevered and found one but it is very cheap so i apologise in advance. it is just somewhere to sleep isnt it?	0
+343	+123 congratulations - in this week's competition draw u have won the ??1450 prize to claim just call 09050002311 b4280703. t&cs/stop sms 08718727868. over 18 only 150ppm	1
+344	u still going to the mall?	0
+345	fffff. can you text kadeem or are you too far gone	0
+346	i feel like a dick because i keep sleeping through your texts and facebook messages. sup, you in town?	0
+347	i'm awake oh. what's up.	0
+348	so do you have samus shoulders yet	0
+349	no wonder... cos i dun rem seeing a silver car... but i thk i saw a black one...	0
+350	err... cud do. i'm going to  at 8pm. i haven't got a way to contact him until then.	0
+351	well. you know what i mean. texting	0
+352	omg it could snow here tonite!	0
+353	u can call me now...	0
+354	uncle g, just checking up on you. do have a rewarding month	0
+355	urgent this is our 2nd attempt to contact u. your ??900 prize from yesterday is still awaiting collection. to claim call now 09061702893	1
+356	hi if ur lookin 4 saucy daytime fun wiv busty married woman am free all next week chat now 2 sort time 09099726429 janinexx calls??1/minmobsmorelkpobox177hp51fl	1
+357	i ask if u meeting da ge tmr nite...	0
+358	i felt so...not any conveying reason.. ese he... what about me?	0
+359	home so we can always chat	0
+360	everybody had fun this evening. miss you.	0
+361	u have won a nokia 6230 plus a free digital camera. this is what u get when u win our free auction. to take part send nokia to 83383 now. pobox114/14tcr/w1 16	1
+362	that day you asked about anand number. why:-)	0
+363	i'm not coming over, do whatever you want	0
+364	ffffffffff. alright no way i can meet up with you sooner?	0
+365	yes baby! we can study all the positions of the kama sutra ;)	0
+366	tmrw. im finishing 9 doors	0
+367	i don't know u and u don't know me. send chat to 86688 now and let's find each other! only 150p/msg rcvd. hg/suite342/2lands/row/w1j6hl ldn. 18 years or over.	1
+368	urgent! your mobile no 07xxxxxxxxx won a ??2,000 bonus caller prize on 02/06/03! this is the 2nd attempt to reach you! call 09066362231 asap! box97n7qp, 150ppm	1
+369	hello hun how ru? its here by the way. im good. been on 2 dates with that guy i met in walkabout so far. we have to meet up soon. hows everyone else?	0
+370	says the  &lt;#&gt;  year old with a man and money. i'm down to my last  &lt;#&gt; . still waiting for that check.	0
+371	yes obviously, but you are the eggs-pert and the potato head??_ speak soon!	0
+372	for ur chance to win ??250 cash every wk txt: play to 83370. t's&c's www.music-trivia.net custcare 08715705022, 1x150p/wk.	1
+373	free entry to the gr8prizes wkly comp 4 a chance to win the latest nokia 8800, psp or ??250 cash every wk.txt great to 80878 http//www.gr8prizes.com 08715705022	1
+374	rose for red,red for blood,blood for heart,heart for u. but u for me.... send tis to all ur friends.. including me.. if u like me.. if u get back, 1-u r poor in relation! 2-u need some 1 to support 3-u r frnd 2 many 4-some1 luvs u 5+- some1 is praying god to marry u.:-) try it....	0
+375	hey mr whats the name of that bill brison book the one about language and words	0
+376	ya srsly better than yi tho	0
+377	ok which your another number	0
+378	uncle abbey! happy new year. abiola	0
+379	i am in tirupur da, once you started from office call me.	0
+380	hey! congrats 2u2. id luv 2 but ive had 2 go home!	0
+381	win a year supply of cds 4 a store of ur choice worth ??500 & enter our ??100 weekly draw txt music to 87066 ts&cs www.ldew.com.subs16+1win150ppmx3	1
+382	i just saw ron burgundy captaining a party boat so yeah	0
+383	where is that one day training:-)	0
+384	for sale - arsenal dartboard. good condition but no doubles or trebles!	1
+385	hey do you want anything to buy:)	0
+386	oh, yes, i've just been a little under the weather so i've kind of been coccooning at home	0
+387	hi babe its jordan, how r u? im home from abroad and lonely, text me back if u wanna chat xxsp visionsms.com text stop to stopcost 150p 08712400603	1
+388	private! your 2003 account statement for shows 800 un-redeemed s.i.m. points. call 08718738001 identifier code: 49557 expires 26/11/04	1
+389	u haven??t lost me ill always b here 4u.i didn??t intend 2 hurt u but i never knew how u felt about me when iwas+marine&that??s what itried2tell urmom.i careabout u	0
+390	january male sale! hot gay chat now cheaper, call 08709222922. national rate from 1.5p/min cheap to 7.8p/min peak! to stop texts call 08712460324 (10p/min)	1
+391	customer service announcement. we recently tried to make a delivery to you but were unable to do so, please call 07099833605 to re-schedule. ref:9280114	1
+392	good morning princess! have a great day!	0
+393	cool. so how come you havent been wined and dined before?	0
+394	got what it takes 2 take part in the wrc rally in oz? u can with lucozade energy! text rally le to 61200 (25p), see packs or lucozade.co.uk/wrc & itcould be u!	1
+395	bognor it is! should be splendid at this time of year.	0
+396	yar lor... keep raining non stop... or u wan 2 go elsewhere?	0
+397	ok anyway no need to change with what you said	0
+398	call from 08702490080 - tells u 2 call 09066358152 to claim ??5000 prize. u have 2 enter all ur mobile & personal details @ the prompts. careful!	1
+399	do you want 750 anytime any network mins 150 text and a new video phone for only five pounds per week call 08002888812 or reply for delivery tomorrow	1
+400	8007 free for 1st week! no1 nokia tone 4 ur mob every week just txt nokia to 8007 get txting and tell ur mates www.getzed.co.uk pobox 36504 w4 5wq norm 150p/tone 16+	1
+401	dai i downloaded but there is only exe file which i can only run that exe after installing.	0
+402	yes :)it completely in out of form:)clark also utter waste.	0
+403	oh and by the way you do have more food in your fridge! want to go out for a meal tonight?	0
+404	am slow in using biola's fne	0
+405	night night, see you tomorrow	0
+406	they r giving a second chance to rahul dengra.	0
+407	yes i thought so. thanks.	0
+408	ok then i will come to ur home after half an hour	0
+409	storming msg: wen u lift d phne, u say \hello\" do u knw wt is d real meaning of hello?? . . . it's d name of a girl..! . . . yes.. and u knw who is dat girl?? \"margaret hello\" she is d girlfrnd f grahmbell who invnted telphone... . . . . moral:one can 4get d name of a person	0
+410	hello from orange. for 1 month's free access to games, news and sport, plus 10 free texts and 20 photo messages, reply yes. terms apply: www.orange.co.uk/ow	1
+411	but that's on ebay it might be less elsewhere.	0
+412	k... must book a not huh? so going for yoga basic on sunday?	0
+413	sorry,in meeting i'll call later	0
+414	im good! i have been thinking about you...	0
+415	actually, my mobile is full of msg. and i m doing a work online, where i need to send them  &lt;#&gt;  sent msg i wil explain u later.	0
+416	i will take care of financial problem.i will help:)	0
+417	aiyo... u always c our ex one... i dunno abt mei, she haven reply... first time u reply so fast... y so lucky not workin huh, got bao by ur sugardad ah...gee..	0
+418	buy space invaders 4 a chance 2 win orig arcade game console. press 0 for games arcade (std wap charge) see o2.co.uk/games 4 terms + settings. no purchase	1
+419	hey tmr meet at bugis 930 ?	0
+420	i'm at home. please call	0
+421	sorry. || mail? ||	0
+422	can meh? thgt some will clash... really ah, i dun mind... i dun seen to have lost any weight... gee...	0
+423	todays voda numbers ending 7548 are selected to receive a $350 award. if you have a match please call 08712300220 quoting claim code 4041 standard rates app	1
+424	o ic lol. should play 9 doors sometime yo	0
+425	ur cash-balance is currently 500 pounds - to maximize ur cash-in now send collect to 83600 only 150p/msg. cc: 08718720201 po box 114/14 tcr/w1	1
+426	we have to pick rayan macleran there.	0
+427	keep my payasam there if rinu brings	0
+428	ok thanx...	0
+429	excellent, i'll see what riley's plans are	0
+430	nite nite pocay wocay luv u more than n e thing 4eva i promise ring u 2morrowxxxx	0
+431	sorry, i'll call later	0
+432	i want to sent  &lt;#&gt; mesages today. thats y. sorry if i hurts	0
+433	so you think i should actually talk to him? not call his boss in the morning? i went to this place last year and he told me where i could go and get my car fixed cheaper. he kept telling me today how much he hoped i would come back in, how he always regretted not getting my number, etc.	0
+434	do you want a new nokia 3510i colour phone delivered tomorrow? with 200 free minutes to any mobile + 100 free text + free camcorder reply or call 08000930705	1
+435	cold. dont be sad dear	0
+436	excellent! wish we were together right now!	0
+437	not course. only maths one day one chapter with in one month we can finish.	0
+438	yeah, in fact he just asked if we needed anything like an hour ago. when and how much?	0
+439	theoretically yeah, he could be able to come	0
+440	hi:)did you asked to waheeda fathima about leave?	0
+441	i will once i get home	0
+442	wow. you're right! i didn't mean to do that. i guess once i gave up on boston men and changed my search location to nyc, something changed. cuz on my signin page it still says boston.	0
+443	you've won tkts to the euro2004 cup final or ??800 cash, to collect call 09058099801 b4190604, pobox 7876150ppm	1
+444	mila, age23, blonde, new in uk. i look sex with uk guys. if u like fun with me. text mtalk to 69866.18 . 30pp/txt 1st 5free. ??1.50 increments. help08718728876	1
+445	\hi darlin i cantdo anythingtomorrow as myparents aretaking me outfor a meal. when are u free? katexxx\""	0
+446	it only does simple arithmetic not percentages.	0
+447	yeah i am, so i'll leave maybe 7ish?	0
+448	as in different styles?	0
+449	i need an 8th but i'm off campus atm, could i pick up in an hour or two?	0
+450	either way works for me. i am  &lt;#&gt;  years old. hope that doesnt bother you.	0
+451	yes, princess. are you going to make me moan?	0
+452	so what did the bank say about the money?	0
+453	congrats! 2 mobile 3g videophones r yours. call 09061744553 now! videochat wid ur mates, play java games, dload polyh music, noline rentl. bx420. ip4. 5we. 150pm	1
+454	lol .. *grins* .. i'm not babe, but thanks for thinking of me!	0
+455	nah, i'm a perpetual dd	0
+456	jus finish bathing...	0
+457	i don't know, same thing that's wrong everyso often, he panicks starts goin on bout not bein good enough ??_	0
+458	england v macedonia - dont miss the goals/team news. txt ur national team to 87077 eg england to 87077 try:wales, scotland 4txt/??1.20 poboxox36504w45wq 16+	1
+459	santa calling! would your little ones like a call from santa xmas eve? call 09058094583 to book your time.	1
+460	sir, hope your day is going smoothly. i really hoped i wont have to bother you about this. i have some bills that i can't settle this month. i am out of all extra cash. i know this is a challenging time for you also but i have to let you know.	0
+461	i always chat with you. in fact i need money can you raise me?	0
+462	sorry, i'll call later	0
+463	sad story of a man - last week was my b'day. my wife did'nt wish me. my parents forgot n so did my kids . i went to work. even my colleagues did not wish.	0
+464	going thru a very different feeling.wavering decisions and coping up with the same is the same individual.time will heal everything i believe.	0
+465	haha figures, well i found the piece and priscilla's bowl	0
+466	aiya we discuss later lar... pick ?_ up at 4 is it?	0
+467	pleassssssseeeeee tel me v avent done sportsx	0
+468	please call 08712402972 immediately as there is an urgent message waiting for you	1
+469	today's offer! claim ur ??150 worth of discount vouchers! text yes to 85023 now! savamob, member offers mobile! t cs 08717898035. ??3.00 sub. 16 . unsub reply x	1
+470	so many people seems to be special at first sight, but only very few will remain special to you till your last sight.. maintain them till life ends.. take cr da	0
+471	i have had two more letters from . i will copy them for you cos one has a message for you. speak soon	0
+472	i wont. so wat's wit the guys	0
+473	rock yr chik. get 100's of filthy films &xxx pics on yr phone now. rply filth to 69669. saristar ltd, e14 9yt 08701752560. 450p per 5 days. stop2 cancel	1
+474	3. you have received your mobile content. enjoy	1
+475	hope you are having a good week. just checking in	0
+476	\cheers for callin babe.sozi culdnt talkbut i wannatell u details later wenwecan chat properly x\""	0
+477	it's ok i wun b angry. msg u aft i come home tonight.	0
+478	how? izzit still raining?	0
+479	yo yo yo byatch whassup?	0
+480	aiyo u so poor thing... then u dun wan 2 eat? u bathe already?	0
+481	\hey babe! far 2 spun-out 2 spk at da mo... dead 2 da wrld. been sleeping on da sofa all day  tx 4 fonin hon  call 2mwen im bk frmcloud 9! j x\""	0
+482	anyway i'm going shopping on my own now. cos my sis not done yet. dun disturb u liao.	0
+483	tell her i said eat shit.	0
+484	miles and smiles r made frm same letters but do u know d difference..? smile on ur face keeps me happy even though i am miles away from u.. :-)keep smiling.. good nyt	0
+485	i jokin oni lar.. ?? busy then i wun disturb ?_.	0
+486	they are just making it easy to pay back. i have  &lt;#&gt; yrs to say but i can pay back earlier. you get?	0
+487	cool. do you like swimming? i have a pool and jacuzzi at my house.	0
+488	my superior telling that friday is leave for all other department except ours:)so it will be leave for you:)any way call waheed fathima hr and conform it:)	0
+489	what you did in  leave.	0
+490	you lifted my hopes with the offer of money. i am in need. especially when the end of the month approaches and it hurts my studying. anyways have a gr8 weekend	0
+491	yup. anything lor, if u dun wan it's ok...	0
+492	(i should add that i don't really care and if you can't i can at least get this dude to fuck off but hey, your money if you want it)	0
+493	4 tacos + 1 rajas burrito, right?	0
+494	no calls..messages..missed calls	0
+495	wat's da model num of ur phone?	0
+496	is ur changes 2 da report big? cos i've already made changes 2 da previous report.	0
+497	ok ok take care. i can understand.	0
+498	sorry, i'll call later	0
+499	you could have seen me..i did't recognise you face.:)	0
diff --git a/data/sms/tfidf_feats/test_feats.pickle.gz b/data/sms/tfidf_feats/test_feats.pickle.gz
new file mode 100644
index 0000000000000000000000000000000000000000..4b7b9dad4dfa0cae0e2a43c76dc2cb0cd4b94c42
Binary files /dev/null and b/data/sms/tfidf_feats/test_feats.pickle.gz differ
diff --git a/data/sms/tfidf_feats/train_feats.pickle.gz b/data/sms/tfidf_feats/train_feats.pickle.gz
new file mode 100644
index 0000000000000000000000000000000000000000..685e308086578a966ebf93ce1a233fda0debd3b7
Binary files /dev/null and b/data/sms/tfidf_feats/train_feats.pickle.gz differ
diff --git a/data/sms/tfidf_feats/valid_feats.pickle.gz b/data/sms/tfidf_feats/valid_feats.pickle.gz
new file mode 100644
index 0000000000000000000000000000000000000000..ffe401cea9be44b90d1ca27db3791fbe6a4a53cd
Binary files /dev/null and b/data/sms/tfidf_feats/valid_feats.pickle.gz differ
diff --git a/data/sms/train.tsv b/data/sms/train.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..68a9eef0eccc8c42ff14303632414f681c281e55
--- /dev/null
+++ b/data/sms/train.tsv
@@ -0,0 +1,4503 @@
+index	sentence1	label
+0	oh ho. is this the first time u use these type of words	0
+1	christmas is an occasion that is celebrated as a reflection of ur... values..., desires..., affections...&amp; traditions.... have an ideal christmas...	0
+2	da my birthdate in certificate is in april but real date is today. but dont publish it. i shall give you a special treat if you keep the secret. any way thanks for the wishes	0
+3	thk shld b can... ya, i wana go 4 lessons... haha, can go for one whole stretch...	0
+4	watching tv lor. nice one then i like lor.	0
+5	would really appreciate if you call me. just need someone to talk to.	0
+6	ur cash-balance is currently 500 pounds - to maximize ur cash-in now send go to 86688 only 150p/msg. cc 08718720201 hg/suite342/2lands row/w1j6hl	1
+7	the search 4 happiness is 1 of d main sources of unhappiness! accept life the way it comes! u will find happiness in every moment u live.	0
+8	never try alone to take the weight of a tear that comes out of ur heart and falls through ur eyes... always remember a stupid friend is here to share... bslvyl	0
+9	reverse is cheating. that is not mathematics.	0
+10	just hopeing that wasn???t too pissed up to remember and has gone off to his sisters or something!	0
+11	ugh i don't wanna get out of bed. it's so warm.	0
+12	... are you in the pub?	0
+13	how much is blackberry bold2 in nigeria.	0
+14	sorry, i'll call later in meeting any thing related to trade please call arul. &lt;#&gt;	0
+15	i am great! how are you?	0
+16	so when do you wanna gym harri	0
+17	try to do something dear. you read something for exams	0
+18	dear how is chechi. did you talk to her	0
+19	ok lor ?_ reaching then message me.	0
+20	call me when u're done...	0
+21	what should i eat fo lunch senor	0
+22	tiwary to rcb.battle between bang and kochi.	0
+23	hmm... dunno leh, mayb a bag 4 goigng out dat is not too small. or jus anything except perfume, smth dat i can keep.	0
+24	wn u r hurt by d prsn who s close 2 u, do fight wit dem. coz somtimes dis fight saves a relation bt being quiet leaves nothin in a relation.. gud eveb-)	0
+25	oh thanks a lot..i already bought 2 eggs ..	0
+26	pdate_now - double mins and 1000 txts on orange tariffs. latest motorola, sonyericsson & nokia & bluetooth free! call mobileupd8 on 08000839402 or call2optout/!yhl	1
+27	guess what! somebody you know secretly fancies you! wanna find out who it is? give us a call on 09065394973 from landline datebox1282essexcm61xn 150p/min 18	1
+28	oh you got many responsibilities.	0
+29	ok.. ?? finishing soon?	0
+30	i'm sorry. i've joined the league of people that dont keep in touch. you mean a great deal to me. you have been a friend at all times even at great personal cost. do have a great week.|	0
+31	its ok chikku, and its my 1 of favourite song..:-)	0
+32	i don't know u and u don't know me. send chat to 86688 now and let's find each other! only 150p/msg rcvd. hg/suite342/2lands/row/w1j6hl ldn. 18 years or over.	1
+33	sorry, i'll call later	0
+34	you are right though. i can't give you the space you want and need. this is really starting to become an issue. i was going to suggest setting a definite move out--if i'm still there-- after greece. but maybe you are ready and should do it now.	0
+35	we have all rounder:)so not required:)	0
+36	urgh, coach hot, smells of chip fat! thanks again, especially for the duvet (not a predictive text word).	0
+37	chile, please! it's only a  &lt;decimal&gt;  hour drive for me. i come down all the time and will be subletting feb-april for audition season.	0
+38	yeah. i got a list with only u and joanna if i'm feeling really anti social	0
+39	haha... yup hopefully  we will lose a few kg by mon. after hip hop can go orchard and weigh again	0
+40	say thanks2.	0
+41	no. it's not pride. i'm almost  &lt;#&gt;  years old and shouldn't be takin money from my kid. you're not supposed to have to deal with this stuff. this is grownup stuff--why i don't tell you.	0
+42	fighting with the world is easy, u either win or lose bt fightng with some1 who is close to u is dificult if u lose - u lose if u win - u still lose.	0
+43	just sing hu. i think its also important to find someone female that know the place well preferably a citizen that is also smart to help you navigate through. even things like choosing a phone plan require guidance. when in doubt ask especially girls.	0
+44	todays vodafone numbers ending with 4882 are selected to a receive a ??350 award. if your number matches call 09064019014 to receive your ??350 award.	1
+45	dunno, my dad said he coming home 2 bring us out 4 lunch. yup i go w u lor. i call u when i reach school lor...	0
+46	u 447801259231 have a secret admirer who is looking 2 make contact with u-find out who they r*reveal who thinks ur so special-call on 09058094597	1
+47	still at west coast... haiz... ??'ll take forever to come back...	0
+48	ok then i come n pick u at engin?	0
+49	on a tuesday night r u 4 real	0
+50	dear where you. call me	0
+51	hi my darlin im on my way to london and we have just been smashed into by another driver! and have a big dent! im really missing u what have u been up to? xxx	0
+52	yup. wun believe wat? u really neva c e msg i sent shuhui?	0
+53	i know dat feelin had it with pete! wuld get with em , nuther place nuther time mayb?	0
+54	sorry, i'll call later	0
+55	if you still havent collected the dough pls let me know so i can go to the place i sent it to get the control number	0
+56	might ax well im there.	0
+57	that's the trouble with classes that go well - you're due a dodgey one ??_ expecting mine tomo! see you for recovery, same time, same place	0
+58	we walked from my moms. right on stagwood pass right on winterstone left on victors hill. address is &lt;#&gt;	0
+59	isn't frnd a necesity in life? imagine urself witout a frnd.. hw'd u feel at ur colleg? wat'll u do wth ur cell? wat abt functions? thnk abt events espe'll cared, missed &amp; irritated u? 4wrd it to all those dear-loving frnds wthout whom u cant live.. i jst did it.. takecare..:) goodmorning	0
+60	natalja (25/f) is inviting you to be her friend. reply yes-440 or no-440 see her: www.sms.ac/u/nat27081980 stop? send stop frnd to 62468	1
+61	didn't try, g and i decided not to head out	0
+62	lol no. just trying to make your day a little more interesting	0
+63	but i haf enuff space got like 4 mb...	0
+64	that depends. how would you like to be treated? :)	0
+65	see the forwarding message for proof	0
+66	nah it's straight, if you can just bring bud or drinks or something that's actually a little more useful than straight cash	0
+67	er yeah, i will b there at 15:26, sorry! just tell me which pub/cafe to sit in and come wen u can	0
+68	guess he wants alone time. we could just show up and watch when they do..	0
+69	no:-)i got rumour that you going to buy apartment in chennai:-)	0
+70	no. she's currently in scotland for that.	0
+71	i hope you know i'm still mad at you.	0
+72	if you can make it any time tonight or whenever you can it's cool, just text me whenever you're around	0
+73	haf u eaten? wat time u wan me 2 come?	0
+74	we are okay. going to sleep now. later	0
+75	we tried to contact you re your reply to our offer of a video phone 750 anytime any network mins half price line rental camcorder reply or call 08000930705	1
+76	i know you are thinkin malaria. but relax, children cant handle malaria. she would have been worse and its gastroenteritis. if she takes enough to replace her loss her temp will reduce. and if you give her malaria meds now she will just vomit. its a self limiting illness she has which means in a few days it will completely stop	0
+77	hi neva worry bout da truth coz the truth will lead me 2 ur heart. it??s the least a unique person like u deserve. sleep tight or morning	0
+78	i only haf msn. it's yijue@hotmail.com	0
+79	it is a good thing i'm now getting the connection to bw	0
+80	tee hee. off to lecture, cheery bye bye.	0
+81	tomorrow i am not going to theatre. . . so i can come wherever u call me. . . tell me where and when to come tomorrow	0
+82	how's it feel? mr. your not my real valentine just my yo valentine even tho u hardly play!!	0
+83	not from this campus. are you in the library?	0
+84	....photoshop makes my computer shut down.	0
+85	can you do online transaction?	0
+86	k.:)do it at evening da:)urgent:)	0
+87	nope. meanwhile she talk say make i greet you.	0
+88	or ill be a little closer like at the bus stop on the same street	0
+89	ok. not much to do here though. h&m friday, cant wait. dunno wot the hell im gonna do for another 3 weeks! become a slob- oh wait, already done that!	0
+90	you are always putting your business out there. you put pictures of your ass on facebook. you are one of the most open people i've ever met. why would i think a picture of your room would hurt you, make you feel violated.	0
+91	kinda. first one gets in at twelve! aah. speak tomo	0
+92	hi elaine, is today's meeting confirmed?	0
+93	call me. i m unable to cal. lets meet bhaskar, and deep	0
+94	no no. i will check all rooms befor activities	0
+95	\i;m reaching in another 2 stops.\""	0
+96	anyway holla at me whenever you're around because i need an excuse to go creep on people in sarasota	0
+97	do you know when dad will be back?	0
+98	hi, the sexychat girls are waiting for you to text them. text now for a great night chatting. send stop to stop this service	1
+99	check wid corect speling i.e. sarcasm	0
+100	wot about on wed nite i am 3 then but only til 9!	0
+101	how are you. wish you a great semester	0
+102	i want kfc its tuesday. only buy 2 meals only 2. no gravy. only 2 mark. 2!	0
+103	should i have picked up a receipt or something earlier	0
+104	ok... help me ask if she's working tmr a not?	0
+105	sunshine quiz wkly q! win a top sony dvd player if u know which country the algarve is in? txt ansr to 82277. ??1.50 sp:tyrone	1
+106	do we have any spare power supplies	0
+107	say this slowly.? god,i love you &amp; i need you,clean my heart with your blood.send this to ten special people &amp; u c miracle tomorrow, do it,pls,pls do it...	0
+108	misplaced your number and was sending texts to your old number. wondering why i've not heard from you this year. all the best in your mcat. got this number from my atlanta friends	0
+109	neft transaction with reference number  &lt;#&gt;  for rs. &lt;decimal&gt;  has been credited to the beneficiary account on  &lt;#&gt;  at  &lt;time&gt; : &lt;#&gt;	0
+110	today is accept day..u accept me as? brother sister lover dear1 best1 clos1 lvblefrnd jstfrnd cutefrnd lifpartnr belovd swtheart bstfrnd no rply means enemy	0
+111	dear, take care. i am just reaching home.love u a lot.	0
+112	ps u no ur a grown up now right?	0
+113	sent me de webadres for geting salary slip	0
+114	all day working day:)except saturday and sunday..	0
+115	happy new year to u too!	0
+116	hey ! i want you ! i crave you ! i miss you ! i need you ! i love you, ahmad saeed al hallaq ...	0
+117	dear subscriber ur draw 4 ??100 gift voucher will b entered on receipt of a correct ans. when was elvis presleys birthday? txt answer to 80062	1
+118	not yet chikku..going to room nw, i'm in bus..	0
+119	how much r ?_ willing to pay?	0
+120	or remind me in a few hrs.	0
+121	call me when you get the chance plz &lt;3	0
+122	ok then no need to tell me anything i am going to sleep good night	0
+123	sure! i am driving but will reach my destination soon.	0
+124	ok lor...	0
+125	is that what time you want me to come?	0
+126	i uploaded mine to facebook	0
+127	aiyar u so poor thing... i give u my support k... jia you! i'll think of u...	0
+128	sitting ard nothing to do lor. u leh busy w work?	0
+129	ur cash-balance is currently 500 pounds - to maximize ur cash-in now send cash to 86688 only 150p/msg. cc: 08718720201 po box 114/14 tcr/w1	1
+130	i was slept that time.you there?	0
+131	ok give me 5 minutes i think i see her. btw you're my alibi. you were cutting my hair the whole time.	0
+132	no da. . vijay going to talk in jaya tv	0
+133	tell where you reached	0
+134	why you dint come with us.	0
+135	up to ?_... ?? wan come then come lor... but i din c any stripes skirt...	0
+136	urgent! your mobile was awarded a ??1,500 bonus caller prize on 27/6/03. our final attempt 2 contact u! call 08714714011	1
+137	how's my loverboy doing ? what does he do that keeps him from coming to his queen, hmmm ? doesn't he ache to speak to me ? miss me desparately ?	0
+138	just sent it. so what type of food do you like?	0
+139	not a lot has happened here. feels very quiet. beth is at her aunts and charlie is working lots. just me and helen in at the mo. how have you been?	0
+140	and smile for me right now as you go and the world will wonder what you are smiling about and think your crazy and keep away from you ... *grins*	0
+141	oh:)as usual vijay film or its different?	0
+142	i'm home. ard wat time will u reach?	0
+143	how about getting in touch with folks waiting for company? just txt back your name and age to opt in! enjoy the community (150p/sms)	1
+144	come round, it's .	0
+145	i borrow ur bag ok.	0
+146	i dont know what to do to come out of this so only am ask questions like this dont mistake me.	0
+147	she said,'' do u mind if i go into the bedroom for a minute ? '' ''ok'', i sed in a sexy mood. she came out 5 minuts latr wid a cake...n my wife,	0
+148	lul im gettin some juicy gossip at the hospital. two nurses are talking about how fat they are gettin. and one thinks shes obese. oyea.	0
+149	then we gotta do it after that	0
+150	when people see my msgs, they think iam addicted to msging... they are wrong, bcoz they don\'t know that iam addicted to my sweet friends..!! bslvyl	0
+151	wow. i never realized that you were so embarassed by your accomodations. i thought you liked it, since i was doing the best i could and you always seemed so happy about \the cave\". i'm sorry i didn't and don't have more to give. i'm sorry i offered. i'm sorry your room was so embarassing."	0
+152	thanks for your ringtone order, ref number r836. your mobile will be charged ??4.50. should your tone not arrive please call customer services on 09065069154	1
+153	now thats going to ruin your thesis!	0
+154	height of confidence: all the aeronautics professors wer calld &amp; they wer askd 2 sit in an aeroplane. aftr they sat they wer told dat the plane ws made by their students. dey all hurried out of d plane.. bt only 1 didnt move... he said:\if it is made by my students	0
+155	i'm always looking for an excuse to be in the city.	0
+156	he's in lag. that's just the sad part but we keep in touch thanks to skype	0
+157	old orchard near univ. how about you?	0
+158	as in missionary hook up, doggy hook up, standing...|	0
+159	so gd got free ice cream... i oso wan...	0
+160	hey sathya till now we dint meet not even a single time then how can i saw the situation sathya.	0
+161	well she's in for a big surprise!	0
+162	i cant pick the phone right now. pls send a message	0
+163	s.this will increase the chance of winning.	0
+164	aight, you close by or still down around alex's place?	0
+165	nope but i'll b going 2 sch on fri quite early lor cos mys sis got paper in da morn :-)	0
+166	hey, can you tell me blake's address? carlos wanted me to meet him there but i got lost and he's not answering his phone	0
+167	k, want us to come by now?	0
+168	i was about to do it when i texted. i finished a long time ago and showered and er'ything!	0
+169	anytime lor...	0
+170	haha awesome, be there in a minute	0
+171	meeting u is my work. . . tel me when shall i do my work tomorrow	0
+172	have you laid your airtel line to rest?	0
+173	havent mus ask if u can 1st wat. of meet 4 lunch den u n him meet can already lor. or u wan 2 go ask da ge 1st then confirm w me asap?	0
+174	why didn't u call on your lunch?	0
+175	last chance 2 claim ur ??150 worth of discount vouchers-text yes to 85023 now!savamob-member offers mobile t cs 08717898035. ??3.00 sub. 16 . remove txt x or stop	1
+176	if you ask her or she say any please message.	0
+177	he also knows about lunch menu only da. . i know	0
+178	how tall are you princess?	0
+179	you at mu? you should try to figure out how much money everyone has for gas and alcohol, jay and i are trying to figure out our weed budget	0
+180	it doesnt make sense to take it there unless its free. if you need to know more, wikipedia.com	0
+181	no need for the drug anymore.	0
+182	ha. you don???t know either. i did a a clever but simple thing with pears the other day, perfect for christmas.	0
+183	almost there, see u in a sec	0
+184	apo all other are mokka players only	0
+185	mm yes dear look how i am hugging you both. :-p	0
+186	customer service annoncement. you have a new years delivery waiting for you. please call 07046744435 now to arrange delivery	1
+187	well there's a pattern emerging of my friends telling me to drive up and come smoke with them and then telling me that i'm a weed fiend/make them smoke too much/impede their doing other things so you see how i'm hesitant	0
+188	idc get over here, you are not weaseling your way out of this shit twice in a row	0
+189	s...i will take mokka players only:)	0
+190	how are you doing? hope you've settled in for the new school year. just wishin you a gr8 day	0
+191	yes.mum lookin strong:)	0
+192	didn't you get hep b immunisation in nigeria.	0
+193	r u still working now?	0
+194	hey i will be late... i'm at amk. need to drink tea or coffee	0
+195	dear are you angry i was busy dear	0
+196	i place all ur points on e cultures module already.	0
+197	rofl. its true to its name	0
+198	remember on that day..	0
+199	i think just yourself ??_thanks and see you tomo	0
+200	i call you later, don't have network. if urgnt, sms me.	0
+201	mystery solved! just opened my email and he's sent me another batch! isn't he a sweetie	0
+202	aight yo, dats straight dogg	0
+203	you have won a nokia 7250i. this is what you get when you win our free auction. to take part send nokia to 86021 now. hg/suite342/2lands row/w1jhl 16+	1
+204	ok both our days. so what are you making for dinner tonite? am i invited?	0
+205	cool. i am  &lt;#&gt;  inches long. hope you like them big!	0
+206	you do what all you like	0
+207	men always needs a beautiful, intelligent, caring, loving, adjustable, cooperative wife. but the law allows only one wife....	0
+208	anyway seriously hit me up when you're back because otherwise i have to light up with armand and he always has shit and/or is vomiting	0
+209	it shall be fine. i have avalarr now. will hollalater	0
+210	a lot of this sickness thing going round. take it easy. hope u feel better soon. lol	0
+211	* will be september by then!	0
+212	moji i love you more than words. have a rich day	0
+213	great. p diddy is my neighbor and comes for toothpaste every morning	0
+214	i dont know why she.s not getting your messages	0
+215	i will see in half an hour	0
+216	http//tms. widelive.com/index. wml?id=820554ad0a1705572711&first=true??c c ringtone??	1
+217	super msg da:)nalla timing.	0
+218	for me the love should start with attraction.i should feel that i need her every time around me.she should be the first thing which comes in my thoughts.i would start the day and end it with her.she should be there every time i dream.love will be then when my every breath has her name.my life should happen around her.my life will be named to her.i would cry for her.will give all my happiness and take all her sorrows.i will be ready to fight with anyone for her.i will be in love when i will be doing the craziest things for her.love will be when i don't have to proove anyone that my girl is the most beautiful lady on the whole planet.i will always be singing praises for her.love will be when i start up making chicken curry and end up makiing sambar.life will be the most beautiful then.will get every morning and thank god for the day because she is with me.i would like to say a lot..will tell later..	0
+219	44 7732584351, do you want a new nokia 3510i colour phone deliveredtomorrow? with 300 free minutes to any mobile + 100 free texts + free camcorder reply or call 08000930705.	1
+220	will it help if we propose going back again tomorrow	0
+221	no, i was trying it all weekend ;v	0
+222	wat r u doing?	0
+223	please call amanda with regard to renewing or upgrading your current t-mobile handset free of charge. offer ends today. tel 0845 021 3680 subject to t's and c's	1
+224	from www.applausestore.com monthlysubscription@50p/msg max6/month t&csc web age16 2stop txt stop	1
+225	oh thats late! well have a good night and i will give u a call tomorrow. iam now going to go to sleep night night	0
+226	goodmorning today i am late for  &lt;decimal&gt; min.	0
+227	i'm working technical support :)voice process.	0
+228	yes..gauti and sehwag out of odi series.	0
+229	want a new video phone? 750 anytime any network mins? half price line rental free text for 3 months? reply or call 08000930705 for free delivery	1
+230	hi i'm sue. i am 20 years old and work as a lapdancer. i love sex. text me live - i'm i my bedroom now. text sue to 89555. by textoperator g2 1da 150ppmsg 18+	1
+231	hi :)finally i completed the course:)	0
+232	i am 6 ft. we will be a good combination!	0
+233	wot student discount can u get on books?	0
+234	i'm in school now n i'll be in da lab doing some stuff give me a call when ?_ r done.	0
+235	then ur sis how?	0
+236	yes princess! i want to make you happy...	0
+237	oh yeah,and hav a great time in newquay-send me a postcard !1 look after all the girls while im gone(u know the 1im talkin bout!)xx	0
+238	do you realize that in about 40 years, we'll have thousands of old ladies running around with tattoos?	1
+239	i've reached already.	0
+240	taka lor. wat time u wan 2 come n look 4 us?	0
+241	babe! how goes that day ? what are you up to ? i miss you already, my love ... * loving kiss* ... i hope everything goes well.	0
+242	thank you so much. when we skyped wit kz and sura, we didnt get the pleasure of your company. hope you are good. we've given you ultimatum oh! we are countin down to aburo. enjoy! this is the message i sent days ago	0
+243	smith waste da.i wanna gayle.	0
+244	fantasy football is back on your tv. go to sky gamestar on sky active and play ??250k dream team. scoring starts on saturday, so register now!sky opt out to 88088	1
+245	yeah jay's sort of a fucking retard	0
+246	superb thought- \be grateful that u dont have everything u want. that means u still have an opportunity to be happier tomorrow than u are today.\":-)"	0
+247	no no. i will check all rooms befor activities	0
+248	but i have to. i like to have love and arrange.	0
+249	hurt me... tease me... make me cry... but in the end of my life when i die plz keep one rose on my grave and say stupid i miss u.. have a nice day bslvyl	0
+250	long after i quit. i get on only like 5 minutes a day as it is.	0
+251	congratulations! thanks to a good friend u have won the ??2,000 xmas prize. 2 claim is easy, just call 08718726978 now! only 10p per minute. bt-national-rate	1
+252	ya, i'm referin to mei's ex wat... no ah, waitin 4 u to treat, somebody shld b rich liao...so gd, den u dun have to work frm tmr onwards...	0
+253	sir, good morning. hope you had a good weekend. i called to let you know that i was able to raise  &lt;#&gt;  from my dad. he however said he would make the rest available by mid feb. this amount is still quite short and i was hoping you would help. do have a good day. abiola	0
+254	stop calling everyone saying i might have cancer. my throat hurts to talk. i can't be answering everyones calls. if i get one more call i'm not babysitting on monday	0
+255	i'm tired of arguing with you about this week after week. do what you want and from now on, i'll do the same.	0
+256	do ?_ noe if ben is going?	0
+257	do you always celebrate ny's with your family ?	0
+258	18 days to euro2004 kickoff! u will be kept informed of all the latest news and results daily. unsubscribe send get euro stop to 83222.	1
+259	finish already... yar they keep saying i mushy... i so embarrassed ok...	0
+260	did u receive my msg?	0
+261	so dont use hook up any how	0
+262	do u knw dis no. &lt;#&gt; ?	0
+263	machan you go to gym tomorrow,  i wil come late goodnight.	0
+264	he remains a bro amongst bros	0
+265	enjoy the jamster videosound gold club with your credits for 2 new videosounds+2 logos+musicnews! get more fun from jamster.co.uk! 16+only help? call: 09701213186	1
+266	your free ringtone is waiting to be collected. simply text the password \mix\" to 85069 to verify. get usher and britney. fml  mk17 92h. 450ppw 16"	1
+267	never blame a day in ur life. good days give u happiness. bad days give u experience. both are essential in life! all are gods blessings! good morning.:	0
+268	i'm taking derek &amp; taylor to walmart, if i'm not back by the time you're done just leave the mouse on my desk and i'll text you when priscilla's ready	0
+269	no. on the way home. so if not for the long dry spell the season would have been over	0
+270	lol where do u come up with these ideas?	0
+271	can help u swoop by picking u up from wherever ur other birds r meeting if u want.	0
+272	hey check it da. i have listed da.	0
+273	dear reached railway. what happen to you	0
+274	from tomorrow onwards eve 6 to 3 work.	0
+275	got c... i lazy to type... i forgot ?_ in lect... i saw a pouch but like not v nice...	0
+276	also remember to get dobby's bowl from your car	0
+277	so there's a ring that comes with the guys costumes. it's there so they can gift their future yowifes. hint hint	0
+278	i wasn't well babe, i have swollen glands at my throat ... what did you end up doing ?	0
+279	morning only i can ok.	0
+280	k, text me when you're on the way	0
+281	purity of friendship between two is not about smiling after reading the forwarded message..its about smiling just by seeing the name. gud evng	0
+282	yo you guys ever figure out how much we need for alcohol? jay and i are trying to figure out how much we can safely spend on weed	0
+283	i've not called you in a while. this is hoping it was l8r malaria and that you know that we miss you guys. i miss bani big, so pls give her my love especially. have a great day.	0
+284	anything...	0
+285	lmao but its so fun...	0
+286	hi. hope ur day * good! back from walk, table booked for half eight. let me know when ur coming over.	0
+287	i taught that ranjith sir called me. so only i sms like that. becaus hes verifying about project. prabu told today so only pa dont mistake me..	0
+288	we live in the next  &lt;#&gt; mins	0
+289	v nice! off 2 sheffield tom 2 air my opinions on categories 2 b used 2 measure ethnicity in next census. busy transcribing. :-)	0
+290	when you came to hostel.	0
+291	u need my presnts always bcz u cant mis love. \jeevithathile irulinae neekunna prakasamanu sneham\" prakasam ennal prabha 'that mns prabha is'love' got it. dont mis me...."	0
+292	lol ... oh no babe, i wont be sliding into your place after midnight, but thanks for the invite	0
+293	good morning my dear........... have a great &amp; successful day.	0
+294	he has lots of used ones babe, but the model doesn't help. youi have to bring it over and he'll match it up	0
+295	have you started in skye	0
+296	u've been selected to stay in 1 of 250 top british hotels - for nothing! holiday valued at ??350! dial 08712300220 to claim - national rate call. bx526, sw73ss	1
+297	k.k:)when are you going?	0
+298	crucify is c not s. you should have told me earlier.	0
+299	do well :)all will for little time. thing of good times ahead:	0
+300	you're gonna have to be way more specific than that	0
+301	dis is yijue. i jus saw ur mail. in case huiming havent sent u my num. dis is my num.	0
+302	cool, i'll text you when i'm on the way	0
+303	just woke up. yeesh its late. but i didn't fall asleep til &lt;#&gt; am :/	0
+304	sorry, i'll call later	0
+305	indeed and by the way it was either or - not both !	0
+306	oh my god. i'm almost home	0
+307	i ain't answerin no phone at what is actually a pretty reasonable hour but i'm sleepy	0
+308	hanks lotsly!	0
+309	wen ur lovable bcums angry wid u, dnt take it seriously.. coz being angry is d most childish n true way of showing deep affection, care n luv!.. kettoda manda... have nice day da.	0
+310	booked ticket for pongal?	0
+311	k will do, addie &amp; i are doing some art so i'll be here when you get home	0
+312	trust me. even if isn't there, its there.	0
+313	if you r @ home then come down within 5 min	0
+314	ah you see. you have to be in the lingo. i will let you know wot on earth it is when has finished making it!	0
+315	jane babes not goin 2 wrk, feel ill after lst nite. foned in already cover 4 me chuck.:-)	0
+316	send me the new number	0
+317	that would be good ??_ i'll phone you tomo lunchtime, shall i, to organise something?	0
+318	yeah do! don???t stand to close tho- you???ll catch something!	0
+319	heart is empty without love.. mind is empty without wisdom.. eyes r empty without dreams &amp; life is empty without frnds.. so alwys be in touch. good night &amp; sweet dreams	0
+320	k..k.:)congratulation ..	0
+321	me fine..absolutly fine	0
+322	great! i shoot big loads so get ready!	0
+323	no message..no responce..what happend?	0
+324	you still at grand prix?	0
+325	sorry, i can't text &amp; drive coherently, see you in twenty	0
+326	a famous quote : when you develop the ability to listen to 'anything' unconditionally without losing your temper or self confidence, it means you are ......... 'married'	0
+327	he's an adult and would learn from the experience. there's no real danger. i just dont like peeps using drugs they dont need. but no comment	0
+328	tell you what, if you make a little spreadsheet and track whose idea it was to smoke to determine who \smokes too much\" for the entire month of february	0
+329	if he started searching he will get job in few days.he have great potential and talent.	0
+330	im in inperialmusic listening2the weirdest track ever by??leafcutter john??-sounds like insects being molested&someone plumbing,remixed by evil men on acid!	0
+331	so can collect ur laptop?	0
+332	s....s...india going to draw the series after many years in south african soil..	0
+333	heart is empty without love.. mind is empty without wisdom.. eyes r empty without dreams &amp; life is empty without frnds.. so alwys be in touch. good night &amp; sweet dreams	0
+334	sir, i have been late in paying rent for the past few months and had to pay a $ &lt;#&gt;  charge. i felt it would be inconsiderate of me to nag about something you give at great cost to yourself and that's why i didnt speak up. i however am in a recession and wont be able to pay the charge this month hence my askin well ahead of month's end. can you please help. thanks	0
+335	love it! daddy will make you scream with pleasure! i am going to slap your ass with my dick!	0
+336	oi. ami parchi na re. kicchu kaaj korte iccha korche na. phone ta tul na. plz. plz.	0
+337	?? comin to fetch us oredi...	0
+338	cud u tell ppl im gona b a bit l8 cos 2 buses hav gon past cos they were full & im still waitin 4 1. pete x	0
+339	ceri u rebel! sweet dreamz me little buddy!! c ya 2moro! who needs blokes	0
+340	alright if you're sure, let me know when you're leaving	0
+341	you will go to walmart. i.ll stay.	0
+342	txt: call to no: 86888 & claim your reward of 3 hours talk time to use from your phone now! subscribe6gbp/mnth inc 3hrs 16 stop?txtstop www.gamb.tv	1
+343	when did dad get back.	0
+344	where r e meeting tmr?	0
+345	frnd s not juz a word.....not merely a relationship.....its a silent promise which says ... \ i will be with you \" wherevr.. whenevr.. forevr... gudnyt dear.."	0
+346	i want  &lt;#&gt;  rs da:)do you have it?	0
+347	what today-sunday..sunday is holiday..so no work..	0
+348	i'm nt goin, got somethin on, unless they meetin 4 dinner lor... haha, i wonder who will go tis time...	0
+349	thanks. fills me with complete calm and reassurance!	0
+350	i think you should go the honesty road. call the bank tomorrow. its the tough decisions that make us great people.	0
+351	cheers for the message zogtorius. i??ve been staring at my phone for an age deciding whether to text or not.	0
+352	hey are we going for the lo lesson or gym?	0
+353	fighting with the world is easy, u either win or lose bt fightng with some1 who is close to u is dificult if u lose - u lose if u win - u still lose.	0
+354	haha... where got so fast lose weight, thk muz go 4 a month den got effect... gee,later we go aust put bk e weight.	0
+355	cool. we will have fun practicing making babies!	0
+356	sos! any amount i can get pls.	0
+357	good afternoon, babe. how goes that day ? any job prospects yet ? i miss you, my love ... *sighs* ... :-(	0
+358	yup ok thanx...	0
+359	k...k:)why cant you come here and search job:)	0
+360	xmas iscoming & ur awarded either ??500 cd gift vouchers & free entry 2 r ??100 weekly draw txt music to 87066 tnc www.ldew.com1win150ppmx3age16subscription	1
+361	cos i want it to be your thing	0
+362	hope you are having a great new semester. do wish you the very best. you are made for greatness.	0
+363	* am on a train back from northampton so i'm afraid not!	0
+364	great. i'm in church now, will holla when i get out	0
+365	buy space invaders 4 a chance 2 win orig arcade game console. press 0 for games arcade (std wap charge) see o2.co.uk/games 4 terms + settings. no purchase	1
+366	so check your errors and if you had difficulties, do correction.	0
+367	no. its not specialisation. can work but its slave labor. will look for it this month sha cos no shakara 4 beggar.	0
+368	hi its kate how is your evening? i hope i can see you tomorrow for a bit but i have to bloody babyjontet! txt back if u can. :) xxx	0
+369	its just the effect of irritation. just ignore it	0
+370	match started.india  &lt;#&gt;  for 2	0
+371	oh ho. is this the first time u use these type of words	0
+372	webpage s not available!	0
+373	i'm on the bus. love you	0
+374	haha awesome, omw back now then	0
+375	hi happy birthday. hi hi hi hi hi hi hi	0
+376	winner!! as a valued network customer you have been selected to receivea ??900 prize reward! to claim call 09061701461. claim code kl341. valid 12 hours only.	1
+377	good morning princess! how are you?	0
+378	dear voucher holder, 2 claim this weeks offer, at your pc go to http://www.e-tlp.co.uk/expressoffer ts&cs apply.2 stop texts txt stop to 80062.	1
+379	my sister going to earn more than me da.	0
+380	you have 1 new voicemail. please call 08719181513.	1
+381	splashmobile: choose from 1000s of gr8 tones each wk! this is a subscrition service with weekly tones costing 300p. u have one credit - kick back and enjoy	1
+382	y de asking like this.	0
+383	no. did you multimedia message them or e-mail?	0
+384	have you emigrated or something? ok maybe 5.30 was a bit hopeful...	0
+385	yo chad which gymnastics class do you wanna take? the site says christians class is full..	0
+386	true. its easier with her here.	0
+387	call me, i am senthil from hsbc.	0
+388	oh, my love, it's soooo good to hear from you. omg i missed you so much today. i'm sorry your having problems with the provider but thank you for tming me	0
+389	was the actual exam harder than nbme	0
+390	new theory: argument wins d situation, but loses the person. so dont argue with ur friends just.. . . . kick them &amp; say, i'm always correct.!	0
+391	hello. sort of out in town already. that . so dont rush home, i am eating nachos. will let you know eta.	0
+392	ur balance is now ??600. next question: complete the landmark, big, a. bob, b. barry or c. ben ?. text a, b or c to 83738. good luck!	1
+393	ya ok, vikky vl c witin  &lt;#&gt; mins and il reply u..	0
+394	please call our customer service representative on 0800 169 6031 between 10am-9pm as you have won a guaranteed ??1000 cash or ??5000 prize!	1
+395	oops sorry. just to check that you don't mind picking me up tomo at half eight from station. would that be ok?	0
+396	doing nothing, then u not having dinner w us?	0
+397	did you try making another butt.	0
+398	hi..i got the money da:)	0
+399	you in your room? i need a few	0
+400	please call 08712402902 immediately as there is an urgent message waiting for you.	1
+401	then she dun believe wat?	0
+402	urgent urgent! we have 800 free flights to europe to give away, call b4 10th sept & take a friend 4 free. call now to claim on 09050000555. ba128nnfwfly150ppm	1
+403	hey, looks like i was wrong and one of the kappa guys numbers is still on my phone, if you want i can text him and see if he's around	0
+404	i will cme i want to go to hos 2morow. after that i wil cme. this what i got from her dear what to do. she didnt say any time	0
+405	at what time should i come tomorrow	0
+406	s:)s.nervous  &lt;#&gt; :)	0
+407	leaving to qatar tonite in search of an opportunity.all went fast.pls add me in ur prayers dear.rakhesh	0
+408	your free ringtone is waiting to be collected. simply text the password \mix\" to 85069 to verify. get usher and britney. fml  mk17 92h. 450ppw 16"	1
+409	networking technical support associate.	0
+410	k..k..any special today?	0
+411	u having lunch alone? i now so bored...	0
+412	oh yah... we never cancel leh... haha	0
+413	i'm wif him now buying tix lar...	0
+414	are you planning to come chennai?	0
+415	does not operate after  &lt;#&gt;  or what	0
+416	i dont thnk its a wrong calling between us	0
+417	oh really? perform, write a paper, go to a movie and be home by midnight, huh?	0
+418	i called but no one pick up e phone. i ask both of them already they said ok.	0
+419	aight, i'm chillin in a friend's room so text me when you're on the way	0
+420	hmm, too many of them unfortunately... pics obviously arent hot cakes. its kinda fun tho	0
+421	haha... hope ?_ can hear the receipt sound... gd luck!	0
+422	you unbelievable faglord	0
+423	i'm an actor. when i work, i work in the evening and sleep late. since i'm unemployed at the moment, i always sleep late. when you're unemployed, every day is saturday.	0
+424	hello boytoy ! geeee ... i'm missing you today. i like to send you a tm and remind you i'm thinking of you ... and you are loved ... *loving kiss*	0
+425	havent.	0
+426	hmm .. bits and pieces lol ... *sighs* ...	0
+427	how are you. its been ages. how's abj	0
+428	whom you waited for yesterday	0
+429	me too. mark is taking forever to pick up my prescription and the pain is coming back.	0
+430	sorry i missed your call. can you please call back.	0
+431	ask g or iouri, i've told the story like ten times already	0
+432	freemsg: hey - i'm buffy. 25 and love to satisfy men. home alone feeling randy. reply 2 c my pix! qlynnbv help08700621170150p a msg send stop to stop txts	1
+433	baaaaabe! i misss youuuuu ! where are you ? i have to go and teach my class at 5 ...	0
+434	ok.	0
+435	yes..but they said its it.,	0
+436	cool breeze... bright sun... fresh flower... twittering birds... all these waiting to wish u: \goodmorning &amp; have a nice day\" :)"	0
+437	ya but it cant display internal subs so i gotta extract them	0
+438	urgent! call 09061749602 from landline. your complimentary 4* tenerife holiday or ??10,000 cash await collection sae t&cs box 528 hp20 1yf 150ppm 18+	1
+439	kindly send some one to our flat before  &lt;decimal&gt;  today.	0
+440	let's pool our money together and buy a bunch of lotto tickets. if we win i get &lt;#&gt; % u get &lt;#&gt; %. deal?	0
+441	badrith is only for chennai:)i will surely pick for us:)no competition for him.	0
+442	spending new years with my brother and his family. lets plan to meet next week. are you ready to be spoiled? :)	0
+443	sorry da thangam, very very sorry i am held up with prasad.	0
+444	urgent! you have won a 1 week free membership in our ??100,000 prize jackpot! txt the word: claim to no: 81010 t&c www.dbuk.net lccltd pobox 4403ldnw1a7rw18	1
+445	i am back. good journey! let me know if you need any of the receipts. shall i tell you like the pendent?	0
+446	awesome, plan to get here any time after like  &lt;#&gt; , i'll text you details in a wee bit	0
+447	i.ll post her out l8r. in class	0
+448	u can win ??100 of music gift vouchers every week starting now txt the word draw to 87066 tscs www.idew.com skillgame, 1winaweek, age16. 150ppermesssubscription	1
+449	i think chennai well settled?	0
+450	\ey! calm downon theacusations.. itxt u cos iwana know wotu r doin at thew/end... haventcn u in ages..ring me if ur up4 nething sat.love j xxx.\""	0
+451	otherwise had part time job na-tuition..	0
+452	send ur birthdate with month and year, i will tel u ur life partner's name. and the method of calculation. reply must.	0
+453	* am on my way	0
+454	s s..first time..dhoni rocks...	0
+455	fantasy football is back on your tv. go to sky gamestar on sky active and play ??250k dream team. scoring starts on saturday, so register now!sky opt out to 88088	1
+456	ya! when are ?_ taking ure practical lessons? i start in june..	0
+457	well done, blimey, exercise, yeah, i kinda remember wot that is, hmm.	0
+458	got it..mail panren paru..	0
+459	u free on sat rite? u wan 2 watch infernal affairs wif me n darren n mayb xy?	0
+460	ugh my leg hurts. musta overdid it on mon.	0
+461	my new years eve was ok. i went to a party with my boyfriend. who is this si then hey	0
+462	what part of \don't initiate\" don't you understand"	0
+463	stupid.its not possible	0
+464	k i'll be there before 4.	0
+465	pls speak to that customer machan.	0
+466	fighting with the world is easy, u either win or lose bt fightng with some1 who is close to u is dificult if u lose - u lose if u win - u still lose.	0
+467	i will be gentle princess! we will make sweet gentle love...	0
+468	ok chinese food on its way. when i get fat you're paying for my lipo.	0
+469	i meant middle left or right?	0
+470	hi, my love! how goes that day? fuck, this morning i woke and dropped my cell on the way down the stairs but it seems alright ... *phews* i miss you !	0
+471	dude avatar 3d was imp. at one point i thought there were actually flies in the room and almost tried hittng one as a reflex	0
+472	still work going on:)it is very small house.	0
+473	yes see ya not on the dot	0
+474	okay lor... wah... like that def they wont let us go... haha... what did they say in the terms and conditions?	0
+475	lmao you know me so well...	0
+476	not yet. just i'd like to keep in touch and it will be the easiest way to do that from barcelona. by the way how ru and how is the house?	0
+477	headin towards busetop	0
+478	oh right, ok. i'll make sure that i do loads of work during the day!  got a really nasty cough today and is dry n shot so that should really help it!	0
+479	shall i bring us a bottle of wine to keep us amused? only joking! i???ll bring one anyway	0
+480	ok... c ya...	0
+481	serious? what like proper tongued her	0
+482	as usual..iam fine, happy &amp; doing well..:)	0
+483	ever thought about living a good life with a perfect partner? just txt back name and age to join the mobile community. (100p/sms)	1
+484	yep get with the program. you're slacking.	0
+485	sir, i am waiting for your call.	0
+486	awesome, i remember the last time we got somebody high for the first time with diesel :v	0
+487	that is wondarfull song	0
+488	hi mate its rv did u hav a nice hol just a message 3 say hello coz haven??t sent u 1 in ages started driving so stay off roads!rvx	0
+489	how are you doing. how's the queen. are you going for the royal wedding	0
+490	what time you think you'll have it? need to know when i should be near campus	0
+491	for ur chance to win a ??250 wkly shopping spree txt: shop to 80878. t's&c's www.txt-2-shop.com custcare 08715705022, 1x150p/wk	1
+492	get a free mobile video player free movie. to collect text go to 89105. its free! extra films can be ordered t's and c's apply. 18 yrs only	1
+493	since when, which side, any fever, any vomitin.	0
+494	ok.ok ok..then..whats ur todays plan	0
+495	fran i decided 2 go n e way im completely broke an knackered i got up bout 3 c u 2mrw love janx p.s this is my dads fone, -no credit	0
+496	pls clarify back if an open return ticket that i have can be preponed for me to go back to kerala.	0
+497	nope... think i will go for it on monday... sorry i replied so late	0
+498	urgent! please call 09061743810 from landline. your abta complimentary 4* tenerife holiday or #5000 cash await collection sae t&cs box 326 cw25wx 150 ppm	1
+499	well its not like you actually called someone a punto. that woulda been worse.	0
+500	anything lor... u decide...	0
+501	i will cal you sir. in meeting	0
+502	there'll be a minor shindig at my place later tonight, you interested?	0
+503	just got some gas money, any chance you and the gang want to go on a grand nature adventure?	0
+504	its a laptop take it with you.	0
+505	do not b late love mum	0
+506	its sunny in california. the weather's just cool	0
+507	you'll not rcv any more msgs from the chat svc. for free hardcore services text go to: 69988 if u get nothing u must age verify with yr network & try again	1
+508	0a$networks allow companies to bill for sms, so they are responsible for their \suppliers\"	1
+509	i'm at bruce &amp; fowler now but i'm in my mom's car so i can't park (long story)	0
+510	now am free call me pa.	0
+511	ya they are well and fine., bbd(pooja) full pimples..even she become quite black..and ur rite here its too cold, wearing sweatter..	0
+512	oh sorry please its over	0
+513	actually i decided i was too hungry so i haven't left yet :v	0
+514	cancel cheyyamo?and get some money back?	0
+515	not that i know of, most people up here are still out of town	0
+516	urgent! we are trying to contact u. todays draw shows that you have won a ??800 prize guaranteed. call 09050001295 from land line. claim a21. valid 12hrs only	1
+517	ok im not sure what time i finish tomorrow but i wanna spend the evening with you cos that would be vewy vewy lubly! love me xxx	0
+518	orange brings you ringtones from all time chart heroes, with a free hit each week! go to ringtones & pics on wap. to stop receiving these tips reply stop.	1
+519	my fri ah... okie lor,goin 4 my drivin den go shoppin after tt...	0
+520	life alle mone,eppolum oru pole allalo	0
+521	hey babe, how's it going ? did you ever figure out where your going for new years ?	0
+522	i cant wait to see you! how were the photos were useful? :)	0
+523	great. have a safe trip. dont panic surrender all.	0
+524	shall i start from hear.	0
+525	ok... but bag again..	0
+526	hmm...my uncle just informed me that he's paying the school directly. so pls buy food.	0
+527	there are no other charges after transfer charges and you can withdraw anyhow you like	0
+528	i thought we were doing a king of the hill thing there.	0
+529	then u going ikea str aft dat?	0
+530	boy you best get yo ass out here quick	0
+531	todays voda numbers ending with 7634 are selected to receive a ??350 reward. if you have a match please call 08712300220 quoting claim code 7684 standard rates apply.	1
+532	i'm at work. please call	0
+533	i think your mentor is , but not 100 percent sure.	0
+534	aiyah then i wait lor. then u entertain me. hee...	0
+535	ree entry in 2 a weekly comp for a chance to win an ipod. txt pod to 80182 to get entry (std txt rate) t&c's apply 08452810073 for details 18+	1
+536	ron say fri leh. n he said ding tai feng cant make reservations. but he said wait lor.	0
+537	(you didn't hear it from me)	0
+538	i meant as an apology from me for texting you to get me drugs at  &lt;#&gt; at night	0
+539	i got it before the new year cos yetunde said she wanted to surprise you with it but when i didnt see money i returned it mid january before the  &lt;#&gt; day return period ended.	0
+540	should i send you naughty pix? :)	0
+541	eh den sat u book e kb liao huh...	0
+542	arun can u transfr me d amt	0
+543	you have won a nokia 7250i. this is what you get when you win our free auction. to take part send nokia to 86021 now. hg/suite342/2lands row/w1jhl 16+	1
+544	yes:)here tv is always available in work place..	0
+545	i'm in a movie... collect car oredi...	0
+546	sorry, left phone upstairs. ok, might be hectic but would be all my birds with one fell swoop. it's a date.	0
+547	reply to win ??100 weekly! where will the 2006 fifa world cup be held? send stop to 87239 to end service	1
+548	what time you coming down later?	0
+549	congratulations ur awarded 500 of cd vouchers or 125gift guaranteed & free entry 2 100 wkly draw txt music to 87066 tncs www.ldew.com1win150ppmx3age16	1
+550	so ?_'ll be submitting da project tmr rite?	0
+551	o. well uv causes mutations. sunscreen is like essential thesedays	0
+552	g.w.r	0
+553	she said,'' do u mind if i go into the bedroom for a minute ? '' ''ok'', i sed in a sexy mood. she came out 5 minuts latr wid a cake...n my wife,	0
+554	you please give us connection today itself before  &lt;decimal&gt;  or refund the bill	0
+555	sindu got job in birla soft ..	0
+556	now only i reached home. . . i am very tired now. . i will come tomorro	0
+557	macha dont feel upset.i can assume your mindset.believe me one evening with me and i have some wonderful plans for both of us.let life begin again.call me anytime	0
+558	lol i was gonna last month. i cashed some in but i left &lt;#&gt; just in case. i was collecting more during the week cause they announced it on the blog.	0
+559	\hi its kate it was lovely to see you tonight and ill phone you tomorrow. i got to sing and a guy gave me his card! xxx\""	0
+560	thesmszone.com lets you send free anonymous and masked messages..im sending this message from there..do you see the potential for abuse???	1
+561	you are a winner u have been specially selected 2 receive ??1000 or a 4* holiday (flights inc) speak to a live operator 2 claim 0871277810910p/min (18+)	1
+562	will u meet ur dream partner soon? is ur career off 2 a flyng start? 2 find out free, txt horo followed by ur star sign, e. g. horo aries	1
+563	friendship poem: dear o dear u r not near but i can hear dont get fear live with cheer no more tear u r always my dear. gud ni8	0
+564	thanks honey but still haven't heard anything i will leave it a bit longer so not 2 crowd him and will try later - great advice thanks hope cardiff is still there!	0
+565	lol alright i was thinkin that too haha	0
+566	cant think of anyone with * spare room off * top of my head	0
+567	1 in cbe. 2 in chennai.	0
+568	i not free today i haf 2 pick my parents up tonite...	0
+569	hai ana tomarrow am coming on morning.  &lt;decimal&gt;  ill be there in sathy then we ll go to rto office. reply me after came to home.	0
+570	i thought i'd get him a watch, just cos thats the kind of thing u get4an18th. and he loves so much!	0
+571	i am not sure about night menu. . . i know only about noon menu	0
+572	save money on wedding lingerie at www.bridal.petticoatdreams.co.uk choose from a superb selection with national delivery. brought to you by weddingfriend	1
+573	if e timing can, then i go w u lor...	0
+574	i can do that! i want to please you both inside and outside the bedroom...	0
+575	die... now i have e toot fringe again...	0
+576	then ?_ come n pick me at 530 ar?	0
+577	wat time ?_ finish?	0
+578	are you available for soiree on june 3rd?	0
+579	thankyou so much for the call. i appreciate your care.	0
+580	pls confirm the time to collect the cheque.	0
+581	if you r @ home then come down within 5 min	0
+582	sorry im getting up now, feel really bad- totally rejected that kinda me thing.	0
+583	good afternoon, my love. it was good to see your words on ym and get your tm. very smart move, my slave ... *smiles* ... i drink my coffee and await you.	0
+584	ok. she'll be ok. i guess	0
+585	yo! howz u? girls never rang after india. l	0
+586	i have no idea where you are	0
+587	final chance! claim ur ??150 worth of discount vouchers today! text yes to 85023 now! savamob, member offers mobile! t cs savamob pobox84, m263uz. ??3.00 subs 16	1
+588	dating:i have had two of these. only started after i sent a text to talk sport radio last week. any connection do you think or coincidence?	1
+589	god asked, \what is forgiveness?\" a little child gave lovely reply	0
+590	let ur heart be ur compass ur mind ur map ur soul ur guide and u will never loose in world....gnun - sent via way2sms.com	0
+591	alrite jod hows the revision goin? keris bin doin a smidgin. n e way u wanna cum over after college?xx	0
+592	ok. can be later showing around 8-8:30 if you want + cld have drink before. wld prefer not to spend money on nosh if you don't mind, as doing that nxt wk.	0
+593	i hope you arnt pissed off but id would really like to see you tomorrow. love me xxxxxxxxxxxxxx	0
+594	no. thank you. you've been wonderful	0
+595	sorry da..today i wont come to play..i have driving clas..	0
+596	oh ya... got hip hop open. haha i was thinking can go for jazz then zoom to cine... actually tonight i'm free leh... and there's a kb lesson tonight	0
+597	where wuld i be without my baby? the thought alone mite break me and i don??t wanna go crazy but everyboy needs his lady xxxxxxxx	0
+598	was just about to ask. will keep this one. maybe that's why you didn't get all the messages we sent you on glo	0
+599	great. hope you are using your connections from mode men also cos you can never know why old friends can lead you to today	0
+600	sorry, i'll call later in meeting.	0
+601	ok., is any problem to u frm him? wats matter?	0
+602	hey. you got any mail?	0
+603	i'm glad. you are following your dreams.	0
+604	lord of the rings:return of the king in store now!reply lotr by 2 june 4 chance 2 win lotr soundtrack cds stdtxtrate. reply stop to end txts	1
+605	dunno lei shd b driving lor cos i go sch 1 hr oni.	0
+606	it's cool, let me know before it kicks off around  &lt;#&gt; , i'll be out and about all day	0
+607	do you want 750 anytime any network mins 150 text and a new video phone for only five pounds per week call 08000776320 now or reply for delivery tomorrow	1
+608	gettin rdy to ship comp	0
+609	oh yeah i forgot. u can only take 2 out shopping at once.	0
+610	god bless.get good sleep my dear...i will pray!	0
+611	update your face book status frequently :)	0
+612	my uncles in atlanta. wish you guys a great semester.	0
+613	i will be outside office take all from there	0
+614	so is there anything specific i should be doing with regards to jaklin or what because idk what the fuck	0
+615	sorry man, accidentally left my phone on silent last night and didn't check it til i got up	0
+616	please tell me you have some of that special stock you were talking about	0
+617	go chase after her and run her over while she's crossing the street	0
+618	i cant talk to you now.i will call when i can.dont keep calling.	0
+619	babe, have you got enough money to pick up bread and milk ? and i'll give you it back when you get home ?	0
+620	somewhr someone is surely made 4 u. and god has decided a perfect time to make u meet dat person. . . . till den, . . . . . enjoy ur crushes..!!!;-)	0
+621	how many times i told in the stage all use to laugh. you not listen aha.	0
+622	but if she.s drinkin i'm ok.	0
+623	aiyo cos i sms ?_ then ?_ neva reply so i wait 4 ?_ to reply lar. i tot ?_ havent finish ur lab wat.	0
+624	nope. since ayo travelled, he has forgotten his guy	0
+625	tell your friends what you plan to do on valentines day @ &lt;url&gt;	0
+626	printer is cool. i mean groovy. wine is groovying	0
+627	yes i will be there. glad you made it.	0
+628	sorry, i'll call later in meeting.	0
+629	you are right. meanwhile how's project twins comin up	0
+630	then i buy.	0
+631	message:some text missing* sender:name missing* *number missing *sent:date missing *missing u a lot thats y everything is missing sent via fullonsms.com	0
+632	wot u wanna do then missy?	0
+633	but i juz remembered i gotta bathe my dog today..	0
+634	thanx 4 puttin da fone down on me!!	0
+635	i've got ten bucks, jay is being noncomittal	0
+636	\hey j! r u feeling any better	0
+637	its like that hotel dusk game i think. you solve puzzles in a area thing	0
+638	i will send them to your email. do you mind  &lt;#&gt;  times per night?	0
+639	we r outside already.	0
+640	don't think about \what u have got\" think about \"how to use it that you have got\" good ni8"	0
+641	you dont know you jabo me abi.	0
+642	stop knowing me so well!	0
+643	this is a long fuckin showr	0
+644	lol they don't know about my awesome phone. i could click delete right now if i want.	0
+645	the guy at the car shop who was flirting with me got my phone number from the paperwork and called and texted me. i'm nervous because of course now he may have my address. should i call his boss and tell him, knowing this may get him fired?	0
+646	party's at my place at usf, no charge (but if you can contribute in any way it is greatly appreciated) and yeah, we got room for one more	0
+647	i wait 4 ?_ inside da car park...	0
+648	tell my  bad character which u dnt lik in me. i'll try to change in  &lt;#&gt; . i ll add tat 2 my new year resolution. waiting for ur reply.be frank...good morning.	0
+649	or i guess  &lt;#&gt;  min	0
+650	i can call in  &lt;#&gt;  min if thats ok	0
+651	you have 1 new message. please call 08712400200.	1
+652	y so late but i need to go n get da laptop...	0
+653	mm so you asked me not to call radio	0
+654	are you coming to day for class.	0
+655	s..antha num corrct dane	0
+656	c movie is juz last minute decision mah. juz watch 2 lar but i tot ?_ not interested.	0
+657	fine. do you remember me.	0
+658	s now only i took tablets . reaction morning only.	0
+659	how much for an eighth?	0
+660	5p 4 alfie moon's children in need song on ur mob. tell ur m8s. txt tone charity to 8007 for nokias or poly charity for polys: zed 08701417012 profit 2 charity.	1
+661	ladies first and genus second k .	0
+662	mum, i've sent you many many messages since i got here. i just want to know that you are actually getting them. do enjoy the rest of your day.	0
+663	sorry,  in meeting i'll call you later	0
+664	hmv bonus special 500 pounds of genuine hmv vouchers to be won. just answer 4 easy questions. play now! send hmv to 86688 more info:www.100percent-real.com	1
+665	nope watching tv at home... not going out. v bored...	0
+666	tomarrow i want to got to court. at  &lt;decimal&gt; . so you come to bus stand at 9.	0
+667	not yet chikku..wat abt u?	0
+668	one of best dialogue in cute reltnship..!! \wen i die	0
+669	free game. get rayman golf 4 free from the o2 games arcade. 1st get ur games settings. reply post, then save & activ8. press 0 key for arcade. termsapply	1
+670	happy birthday... may all ur dreams come true...	0
+671	nope... juz off from work...	0
+672	dun b sad.. it's over.. dun thk abt it already. concentrate on ur other papers k.	0
+673	hi dude hw r u da realy mising u today	0
+674	[??_] anyway, many good evenings to u! s	0
+675	nice.nice.how is it working?	0
+676	get ready to put on your excellent sub face :)	0
+677	no da. i am happy that we sit together na	0
+678	if i not meeting ?_ all rite then i'll go home lor. if ?_ dun feel like comin it's ok.	0
+679	ultimately tor motive tui achieve korli.	0
+680	xmas prize draws! we are trying to contact u. todays draw shows that you have won a ??2000 prize guaranteed. call 09058094565 from land line. valid 12hrs only	1
+681	\boo babe! u enjoyin yourjob? u seemed 2 b gettin on well hunny!hope ure ok?take care & i??llspeak 2u soonlots of loveme xxxx.\""	0
+682	that one week leave i put know that time. why.	0
+683	so anyways, you can just go to your gym or whatever, my love *smiles* i hope your ok and having a good day babe ... i miss you so much already	0
+684	i am seeking a lady in the street and a freak in the sheets. is that you?	0
+685	yeah i think my usual guy's still passed out from last night, if you get ahold of anybody let me know and i'll throw down	0
+686	would me smoking you out help us work through this difficult time	0
+687	its ok, called mom instead have fun	0
+688	no sir. that's why i had an 8-hr trip on the bus last week. have another audition next wednesday but i think i might drive this time.	0
+689	it means u could not keep ur words.	0
+690	i dont thnk its a wrong calling between us	0
+691	i was just callin to say hi. take care bruv!	0
+692	no chikku nt yet.. ya i'm free	0
+693	oh... haha... den we shld had went today too... gee, nvm la... kaiez, i dun mind goin jazz oso... scared hiphop open cant catch up...	0
+694	lol! oops sorry! have fun.	0
+695	&lt;#&gt; %of pple marry with their lovers... becz they hav gud undrstndng dat avoids problems. i sent dis 2 u, u wil get gud news on friday by d person you like. and tomorrow will be the best day of your life. dont break this chain. if you break you will suffer. send this to  &lt;#&gt;  frnds in &lt;#&gt;  mins whn u read...	0
+696	wow ... i love you sooo much, you know ? i can barely stand it ! i wonder how your day goes and if you are well, my love ... i think of you and miss you	0
+697	how i noe... she's in da car now... later then c lar... i'm wearing shorts...	0
+698	lmao ok i wont be needing u to do my hair anymore.	0
+699	hey what are you doing. y no reply pa..	0
+700	when did you get to the library	0
+701	yeah, we got one lined up for us	0
+702	have you heard about that job? i'm going to that wildlife talk again tonight if u want2come. its that2worzels and a wizzle or whatever it is?!	0
+703	i want to tel u one thing u should not mistake me k this is the message that you sent:)	0
+704	i think its far more than that but find out. check google maps for a place from your dorm.	0
+705	all e best 4 ur driving tmr :-)	0
+706	i got another job! the one at the hospital doing data analysis or something, starts on monday! not sure when my thesis will got finished	0
+707	god's love has no limit. god's grace has no measure. god's power has no boundaries. may u have god's endless blessings always in ur life...!! gud ni8	0
+708	thank god they are in bed!	0
+709	text & meet someone sexy today. u can find a date or even flirt its up to u. join 4 just 10p. reply with name & age eg sam 25. 18 -msg recd@thirtyeight pence	1
+710	enjoy urself tmr...	0
+711	how is your schedule next week? i am out of town this weekend.	0
+712	good morning pookie pie! lol hope i didn't wake u up	0
+713	only 2% students solved this cat question in 'xam... 5+3+2= &lt;#&gt;  9+2+4= &lt;#&gt;  8+6+3= &lt;#&gt;  then 7+2+5=????? tell me the answer if u r brilliant...1thing.i got d answr.	0
+714	and whenever you and i see we can still hook up too.	0
+715	cuz ibored. and don wanna study	0
+716	thanx...	0
+717	hello- thanx for taking that call. i got a job! starts on monday!	0
+718	* thought i didn't see you.	0
+719	sir goodmorning, once free call me.	0
+720	just send a text. we'll skype later.	0
+721	good afternoon on this glorious anniversary day, my sweet j !! i hope this finds you happy and content, my prey. i think of you and send a teasing kiss from across the sea coaxing images of fond souveniers ... you cougar-pen	0
+722	how much u trying to get?	0
+723	yup i thk so until e shop closes lor.	0
+724	hi. customer loyalty offer:the new nokia6650 mobile from only ??10 at txtauction! txt word: start to no: 81151 & get yours now! 4t&ctxt tc 150p/mtmsg	1
+725	hi there. we have now moved in2 our pub . would be great 2 c u if u cud come up.	0
+726	good morning, my love ... i go to sleep now and wish you a great day full of feeling better and opportunity ... you are my last thought babe, i love you *kiss*	0
+727	guy, no flash me now. if you go call me, call me. how madam. take care oh.	0
+728	can come my room but cannot come my house cos my house still messy... haha...	0
+729	today's offer! claim ur ??150 worth of discount vouchers! text yes to 85023 now! savamob, member offers mobile! t cs 08717898035. ??3.00 sub. 16 . unsub reply x	1
+730	not heard from u4 a while. call 4 rude chat private line 01223585334 to cum. wan 2c pics of me gettin shagged then text pix to 8552. 2end send stop 8552 sam xxx	1
+731	i'm fine. hope you are good. do take care.	0
+732	neither [in sterm voice] - i'm studying. all fine with me! not sure the  thing will be resolved, tho. anyway. have a fab hols	0
+733	congrats kano..whr s the treat maga?	0
+734	i'm gonna say no. sorry. i would but as normal am starting to panic about time. sorry again! are you seeing on tuesday?	0
+735	pls accept me for one day. or am begging you change the number.	0
+736	idk. you keep saying that you're not, but since he moved, we keep butting heads over freedom vs. responsibility. and i'm tired. i have so much other shit to deal with that i'm barely keeping myself together once this gets added to it.	0
+737	i also thk too fast... xy suggest one not me. u dun wan it's ok. going 2 rain leh where got gd.	0
+738	i not busy juz dun wan 2 go so early.. hee..	0
+739	i accidentally brought em home in the box	0
+740	me n him so funny...	0
+741	perhaps * is much easy give your account identification, so i will tomorrow at uni	0
+742	yeah hopefully, if tyler can't do it i could maybe ask around a bit	0
+743	how come guoyang go n tell her? then u told her?	0
+744	finish liao... u?	0
+745	what will we do in the shower, baby?	0
+746	the monthly amount is not that terrible and you will not pay anything till 6months after finishing school.	0
+747	a bit of ur smile is my hppnss, a drop of ur tear is my sorrow, a part of ur heart is my life, a heart like mine wil care for u, forevr as my goodfriend	0
+748	now got tv 2 watch meh? u no work today?	0
+749	nvm take ur time.	0
+750	geeeee ... your internet is really bad today, eh ?	0
+751	you should know now. so how's anthony. are you bringing money. i've school fees to pay and rent and stuff like that. thats why i need your help. a friend in need....|	0
+752	xmas & new years eve tickets are now on sale from the club, during the day from 10am till 8pm, and on thurs, fri & sat night this week. they're selling fast!	1
+753	message important information for o2 user. today is your lucky day! 2 find out why log onto http://www.urawinner.com there is a fantastic surprise awaiting you	1
+754	every monday..nxt week vl be completing..	0
+755	\oh fuck. juswoke up in a bed on a boatin the docks. slept wid 25 year old. spinout! giv u da gossip l8r. xxx\""	0
+756	i can't right this second, gotta hit people up first	0
+757	hurry home u big butt. hang up on your last caller if u have to. food is done and i'm starving. don't ask what i cooked.	0
+758	hey anyway i have to :-)	0
+759	crazy ar he's married. ?? like gd looking guys not me. my frens like say he's korean leona's fave but i dun thk he is. aft some thinking mayb most prob i'll go.	0
+760	you have won! as a valued vodafone customer our computer has picked you to win a ??150 prize. to collect is easy. just call 09061743386	1
+761	pls go ahead with watts. i just wanted to be sure. do have a great weekend. abiola	0
+762	do you know why god created gap between your fingers..? so that, one who is made for you comes &amp; fills those gaps by holding your hand with love..!	0
+763	u should make a fb list	0
+764	ok good then i later come find ?_... c lucky i told ?_ to go earlier... later pple take finish ?_ no more again...	0
+765	lol please do. actually send a pic of yourself right now. i wanna see. pose with a comb and hair dryer or something.	0
+766	\si.como no?!listened2the plaid album-quite gd&the new air1 which is hilarious-also bought??braindance??a comp.ofstuff on aphex??s ;abel	0
+767	o turns out i had stereo love on mi phone under the unknown album.	0
+768	i will be gentle baby! soon you will be taking all  &lt;#&gt;  inches deep inside your tight pussy...	0
+769	then its most likely called mittelschmertz. google it. if you dont have paracetamol dont worry it will go.	0
+770	hi darlin ive just got back and i had a really nice night and thanks so much for the lift see u tomorrow xxx	0
+771	sorry i flaked last night, shit's seriously goin down with my roommate, what you up to tonight?	0
+772	ok...	0
+773	you are not bothering me but you have to trust my answers. pls.	0
+774	im gonnamissu so much!!i would say il send u a postcard buttheres aboutas much chance of merememberin asthere is ofsi not breakin his contract!! luv yaxx	0
+775	yeah no probs - last night is obviously catching up with you... speak soon	0
+776	don't necessarily expect it to be done before you get back though because i'm just now headin out	0
+777	i don't run away frm u... i walk slowly &amp; it kills me that u don't care enough to stop me...	0
+778	i'm back, lemme know when you're ready	0
+779	watching tv now. i got new job :)	0
+780	erm... woodland avenue somewhere. do you get the parish magazine, his telephone number will be in there.	0
+781	i just lov this line: \hurt me with the truth i wil tolerat.bcs ur my someone..... but never comfort me with a lie\" gud ni8 and sweet dreams"	0
+782	that's my honeymoon outfit. :)	0
+783	my love ! how come it took you so long to leave for zaher's? i got your words on ym and was happy to see them but was sad you had left. i miss you	0
+784	send to someone else :-)	0
+785	k do i need a login or anything	0
+786	urgent!! your 4* costa del sol holiday or ??5000 await collection. call 09050090044 now toclaim. sae, tc s, pobox334, stockport, sk38xh, cost??1.50/pm, max10mins	1
+787	are you going to write ccna exam this week??	0
+788	ill call u 2mrw at ninish, with my address that icky american freek wont stop callin me 2 bad jen k eh?	0
+789	i agree. so i can stop thinkin about ipad. can you please ask macho the same question.	0
+790	actually i'm waiting for 2 weeks when they start putting ad.	0
+791	wat ?_ doing now?	0
+792	same here, but i consider walls and bunkers and shit important just because i never play on peaceful but i guess your place is high enough that it don't matter	0
+793	my darling sister. how are you doing. when's school resuming. is there a minimum wait period before you reapply? do take care	0
+794	that's necessarily respectful	0
+795	nice talking to you! please dont forget my pix :) i want to see all of you...	0
+796	sounds like a plan! cardiff is still here and still cold! i'm sitting on the radiator!	0
+797	this msg is for your mobile content order it has been resent as previous attempt failed due to network error queries to customersqueries@netvision.uk.com	1
+798	i'll be late...	0
+799	or ?_ go buy wif him then i meet ?_ later can?	0
+800	nokia phone is lovly..	0
+801	i met you as a stranger and choose you as my friend. as long as the world stands, our friendship never ends. lets be friends forever!!! gud nitz...	0
+802	shall i come to get pickle	0
+803	i???m going to try for 2 months ha ha only joking	0
+804	solve d case : a man was found murdered on  &lt;decimal&gt; . &lt;#&gt;  afternoon. 1,his wife called police. 2,police questioned everyone. 3,wife: sir,i was sleeping, when the murder took place. 4.cook: i was cooking. 5.gardener: i was picking vegetables. 6.house-maid: i went 2 d post office. 7.children: we went 2 play. 8.neighbour: we went 2 a marriage. police arrested d murderer immediately. who's it? reply with reason, if u r brilliant.	0
+805	do u still have plumbers tape and a wrench we could borrow?	0
+806	sorry, i'll call later	0
+807	nice. wait...should you be texting right now? i'm not gonna pay your ticket, ya know!	0
+808	just looked it up and addie goes back monday, sucks to be her	0
+809	did you say bold, then torch later. or one torch and 2bold?	0
+810	watching cartoon, listening music &amp; at eve had to go temple &amp; church.. what about u?	0
+811	hello, my boytoy! i made it home and my constant thought is of you, my love. i hope your having a nice visit but i can't wait till you come home to me ...*kiss*	0
+812	hey gals...u all wanna meet 4 dinner at n??te?	0
+813	hey what happen de. are you alright.	0
+814	congrats! 2 mobile 3g videophones r yours. call 09061744553 now! videochat wid ur mates, play java games, dload polyh music, noline rentl. bx420. ip4. 5we. 150pm	1
+815	dont search love, let love find u. thats why its called falling in love, bcoz u dont force yourself, u just fall and u know there is smeone to hold u... bslvyl	0
+816	nope wif my sis lor... aft bathing my dog then i can bathe... looks like it's going 2 rain soon.	0
+817	i av a new number,  . wil u only use this one,ta.	0
+818	hmm thinking lor...	0
+819	tessy..pls do me a favor. pls convey my birthday wishes to nimya..pls dnt forget it. today is her birthday shijas	0
+820	ringtoneking 84484	1
+821	i cant pick the phone right now. pls send a message	0
+822	ok. me watching tv too.	0
+823	thanks love. but am i doing torch or bold.	0
+824	yup i've finished c ?_ there...	0
+825	\hello u.call wen u finish wrk.i fancy meetin up wiv u all tonite as i need a break from dabooks. did 4 hrs last nite+2 today of wrk!\""	0
+826	hey morning what you come to ask:-) pa...	0
+827	spjanuary male sale! hot gay chat now cheaper, call 08709222922. national rate from 1.5p/min cheap to 7.8p/min peak! to stop texts call 08712460324 (10p/min)	1
+828	i don't know but i'm raping dudes at poker	0
+829	this message is free. welcome to the new & improved sex & dogging club! to unsubscribe from this service reply stop. msgs@150p 18+only	1
+830	have you been practising your curtsey?	0
+831	juz now havent woke up so a bit blur blur... can? dad went out liao... i cant cum now oso...	0
+832	urgent! your mobile no 077xxx won a ??2,000 bonus caller prize on 02/06/03! this is the 2nd attempt to reach you! call 09066362206 asap! box97n7qp, 150ppm	1
+833	what time. i???m out until prob 3 or so	0
+834	the xmas story is peace.. the xmas msg is love.. the xmas miracle is jesus.. hav a blessed month ahead &amp; wish u merry xmas...	0
+835	i wnt to buy a bmw car urgently..its vry urgent.but hv a shortage of  &lt;#&gt; lacs.there is no source to arng dis amt. &lt;#&gt; lacs..thats my prob	0
+836	wat time liao, where still got.	0
+837	ok pa. nothing problem:-)	0
+838	dear dave this is your final notice to collect your 4* tenerife holiday or #5000 cash award! call 09061743806 from landline. tcs sae box326 cw25wx 150ppm	1
+839	what do u want when i come back?.a beautiful necklace as a token of my heart for you.thats what i will give but only to my wife of my liking.be that and see..no one can give you that.dont call me.i will wait till i come.	0
+840	i bought the test yesterday. its something that lets you know the exact day u ovulate.when will get 2u in about 2 to 3wks. but pls pls dont fret. i know u r worried. pls relax. also is there anything in ur past history u need to tell me?	0
+841	look at the fuckin time. what the fuck you think is up	0
+842	merry christmas to u too annie!	0
+843	as per your request 'melle melle (oru minnaminunginte nurungu vettam)' has been set as your callertune for all callers. press *9 to copy your friends callertune	0
+844	sorry, got a late start, we're on the way	0
+845	what class of  &lt;#&gt;  reunion?	0
+846	poor girl can't go one day lmao	0
+847	there is os called ubandu which will run without installing in hard disk...you can use that os to copy the important files in system and give it to repair shop..	0
+848	better than bb. if he wont use it, his wife will or them doctor	0
+849	the hair cream has not been shipped.	0
+850	hi its me you are probably having too much fun to get this message but i thought id txt u cos im bored! and james has been farting at me all night	0
+851	correct. so how was work today	0
+852	awesome, i'll see you in a bit	0
+853	k.k:)apo k.good movie.	0
+854	i realise you are a busy guy and i'm trying not to be a bother. i have to get some exams outta the way and then try the cars. do have a gr8 day	0
+855	you still around? looking to pick up later	0
+856	ok ill send you with in  &lt;decimal&gt;  ok.	0
+857	yes i have. so that's why u texted. pshew...missing you so much	0
+858	two fundamentals of cool life: \walk whoever is the king\"!... gud nyt"	0
+859	well at this right i'm gonna have to get up and check today's steam sales/pee so text me when you want me to come get you	0
+860	ok how you dear. did you call chechi	0
+861	sac will score big hundred.he is set batsman:-)	0
+862	bored of speed dating? try speedchat, txt speedchat to 80155, if you don't like em txt swap and get a new chatter! chat80155 pobox36504w45wq 150p/msg rcd 16	1
+863	nt yet chikku..simple habba..hw abt u?	0
+864	xy trying smth now. u eat already? we havent...	0
+865	do you know where my lab goggles went	0
+866	its a part of checking iq	0
+867	or u ask they all if next sat can a not. if all of them can make it then i'm ok lor.	0
+868	i don't know u and u don't know me. send chat to 86688 now and let's find each other! only 150p/msg rcvd. hg/suite342/2lands/row/w1j6hl ldn. 18 years or over.	1
+869	ok i thk i got it. then u wan me 2 come now or wat?	0
+870	thanks for your ringtone order, reference number x29. your mobile will be charged 4.50. should your tone not arrive please call customer services 09065989180	1
+871	\hello-/@drivby-:0quit edrunk sorry iff pthis makes no senrd-dnot no how ^ dancce 2 drum n basq!ihave fun 2nhite x ros xxxxxxx\""	0
+872	today is sorry day.! if ever i was angry with you, if ever i misbehaved or hurt you? plz plz just slap urself bcoz, its ur fault, i'm basically good	0
+873	did you stitch his trouser	0
+874	then u ask darren go n pick u lor... but i oso sian tmr haf 2 meet lect...	0
+875	i told that am coming on wednesday.	0
+876	hmmm:)how many players selected?	0
+877	who u talking about?	0
+878	ok... sweet dreams...	0
+879	actually fuck that, just do whatever, do find an excuse to be in tampa at some point before january though	0
+880	probably gonna swing by in a wee bit	0
+881	yeah if we do have to get a random dude we need to change our info sheets to party  &lt;#&gt; /7 never study just to be safe	0
+882	k so am i, how much for an 8th? fifty?	0
+883	hi kindly give us back our documents which we submitted for loan from stapati	0
+884	you can never do nothing	0
+885	oh yeah! and my diet just flew out the window	0
+886	oh...i asked for fun. haha...take care. ?_	0
+887	gud ni8.swt drms.take care	0
+888	save yourself the stress. if the person has a dorm account, just send your account details and the money will be sent to you.	0
+889	da is good good player.why he is unsold.	0
+890	you might want to pull out more just in case and just plan on not spending it if you can, i don't have much confidence in derek and taylor's money management	0
+891	sitting in mu waiting for everyone to get out of my suite so i can take a shower	0
+892	good afternoon, my love ! any job prospects ? are you missing me ? what do you do ? are you being lazy and bleak, hmmm ? or happy and filled with my love ?	0
+893	i am literally in bed and have been up for like  &lt;#&gt;  hours	0
+894	ok da, i already planned. i wil pick you.	0
+895	oh unintentionally not bad timing. great. fingers  the trains play along! will give fifteen min warning.	0
+896	mm i had my food da from out	0
+897	hot live fantasies call now 08707509020 just 20p per min ntt ltd, po box 1327 croydon cr9 5wb 0870 is a national rate call	1
+898	alright i have a new goal now	0
+899	i sent lanre fakeye's eckankar details to the mail box	0
+900	i'm at work. please call	0
+901	omg if its not one thing its another. my cat has worms :/ when does this bad day end?	0
+902	excellent! are you ready to moan and scream in ecstasy?	0
+903	they did't play one day last year know even though they have very good team.. like india.	0
+904	hui xin is in da lib.	0
+905	i got a call from a landline number. . . i am asked to come to anna nagar . . . i will go in the afternoon	0
+906	i'm going 4 lunch now wif my family then aft dat i go str 2 orchard lor.	0
+907	yo carlos, a few friends are already asking me about you, you working at all this weekend?	0
+908	i'm in a meeting, call me later at	0
+909	got it! it looks scrumptious... daddy wants to eat you all night long!	0
+910	no my blankets are sufficient, thx	0
+911	dear how you. are you ok?	0
+912	that seems unnecessarily hostile	0
+913	single line with a big meaning::::: \miss anything 4 ur \"best life\" but	0
+914	finally it has happened..! aftr decades..! beer is now cheaper than petrol! the goverment expects us to \drink\". . . but don't \"drive \""	0
+915	can you pls pls send me a mail on all you know about relatives coming to deliver here? all you know about costs, risks, benefits and anything else. thanks.	0
+916	in work now. going have in few min.	0
+917	i am not sure about night menu. . . i know only about noon menu	0
+918	she just broke down a list of reasons why nobody's in town and i can't tell if she's being sarcastic or just faggy	0
+919	goodmorning,my grandfather expired..so am on leave today.	0
+920	jade its paul. y didn??t u txt me? do u remember me from barmed? i want 2 talk 2 u! txt me	0
+921	thanks for your subscription to ringtone uk your mobile will be charged ??5/month please confirm by replying yes or no. if you reply no you will not be charged	1
+922	babes i think i got ur brolly i left it in english wil bring it in 2mrw 4 u luv franxx	0
+923	double eviction this week - spiral and michael and good riddance to them!	0
+924	mum ask ?_ to buy food home...	0
+925	hey you around? i've got enough for a half + the ten i owe you	0
+926	just taste fish curry :-p	0
+927	i really need 2 kiss u i miss u my baby from ur baby 4eva	0
+928	re your call; you didn't see my facebook huh?	0
+929	win: we have a winner! mr. t. foley won an ipod! more exciting prizes soon, so keep an eye on ur mobile or visit www.win-82050.co.uk	1
+930	free entry into our ??250 weekly competition just text the word win to 80086 now. 18 t&c www.txttowin.co.uk	1
+931	talk to g and x about that	0
+932	waqt se pehle or naseeb se zyada kisi ko kuch nahi milta,zindgi wo nahi he jo hum sochte hai zindgi wo hai jo ham jeetey hai..........	0
+933	lol what happens in vegas stays in vegas	0
+934	get 3 lions england tone, reply lionm 4 mono or lionp 4 poly. 4 more go 2 www.ringtones.co.uk, the original n best. tones 3gbp network operator rates apply	1
+935	only saturday and sunday holiday so its very difficult:)	0
+936	wait that's still not all that clear, were you not sure about me being sarcastic or that that's why x doesn't want to live with us	0
+937	k..give back my thanks.	0
+938	u calling me right? call my hand phone...	0
+939	six chances to win cash! from 100 to 20,000 pounds txt> csh11 and send to 87575. cost 150p/day, 6days, 16+ tsandcs apply reply hl 4 info	1
+940	yo we are watching a movie on netflix	0
+941	<forwarded from 21870000>hi - this is your mailbox messaging sms alert. you have 40 matches. please call back on 09056242159 to retrieve your messages and matches cc100p/min	1
+942	wot u up 2 u weirdo?	0
+943	thanks  and ! or bomb and date as my phone wanted to say!	0
+944	burger king - wanna play footy at a top stadium? get 2 burger king before 1st sept and go large or super with coca-cola and walk out a winner	1
+945	love it! the girls at the office may wonder why you are smiling but sore...	0
+946	free camera phones with linerental from 4.49/month with 750 cross ntwk mins. 1/2 price txt bundle deals also avble. call 08001950382 or call2optout/j mf	1
+947	night has ended for another day, morning has come in a special way. may you smile like the sunny rays and leaves your worries at the blue blue bay. gud mrng	0
+948	were gonna go get some tacos	0
+949	a gram usually runs like  &lt;#&gt; , a half eighth is smarter though and gets you almost a whole second gram for  &lt;#&gt;	0
+950	all will come alive.better correct any good looking figure there itself..	0
+951	1000's of girls many local 2 u who r virgins 2 this & r ready 2 4fil ur every sexual need. can u 4fil theirs? text cute to 69911(??1.50p. m)	1
+952	imagine life without me... see.. how fast u are searching me?don't worry.. l'm always there to disturb u.. goodnoon..:)	0
+953	i will come tomorrow di	0
+954	i cant pick the phone right now. pls send a message	0
+955	pls dont restrict her from eating anythin she likes for the next two days.	0
+956	no got new job at bar in airport on satsgettin 4.47per hour but means no lie in! keep in touch	0
+957	er mw im filled tuth is aight	0
+958	ok. very good. its all about making that money.	0
+959	come aftr  &lt;decimal&gt; ..now i m cleaning the house	0
+960	teach me apps da. when you come to college.	0
+961	howz that persons story	0
+962	i attended but nothing is there.	0
+963	lol i would but my mom would have a fit and tell the whole family how crazy and terrible i am	0
+964	email alertfrom: jeri stewartsize: 2kbsubject: low-cost prescripiton drvgsto listen to email call 123	1
+965	free for 1st week! no1 nokia tone 4 ur mob every week just txt nokia to 8007 get txting and tell ur mates www.getzed.co.uk pobox 36504 w45wq norm150p/tone 16+	1
+966	yo im right by yo work	0
+967	princess, i like to make love  &lt;#&gt;  times per night. hope thats not a problem!	0
+968	?? wait 4 me in sch i finish ard 5..	0
+969	we tried to call you re your reply to our sms for a video mobile 750 mins unlimited text free camcorder reply or call now 08000930705 del thurs	1
+970	i just really need shit before tomorrow and i know you won't be awake before like 6	0
+971	anything lar...	0
+972	raji..pls do me a favour. pls convey my birthday wishes to nimya. pls. today is her birthday.	0
+973	u reach orchard already? u wan 2 go buy tickets first?	0
+974	hi its in durban are you still on this number	0
+975	hi. wk been ok - on hols now! yes on for a bit of a run. forgot that i have hairdressers appointment at four so need to get home n shower beforehand. does that cause prob for u?	0
+976	i have a sore throat. it's scratches when i talk	0
+977	the message sent is askin for  &lt;#&gt; dollars. shoul i pay  &lt;#&gt;  or  &lt;#&gt; ?	0
+978	hi, can i please get a  &lt;#&gt;  dollar loan from you. i.ll pay you back by mid february. pls.	0
+979	if he started searching he will get job in few days.he have great potential and talent.	0
+980	what are your new years plans?	0
+981	s...from the training manual it show there is no tech process:)its all about password reset and troubleshooting:)	0
+982	well, i'm glad you didn't find it totally disagreeable ... lol	0
+983	i can make lasagna for you... vodka...	0
+984	hey darlin.. i can pick u up at college if u tell me wen & where 2 mt.. love pete xx	0
+985	i have 2 docs appointments next week.:/ i'm tired of them shoving stuff up me. ugh why couldn't i have had a normal body?	0
+986	you have won a guaranteed ??1000 cash or a ??2000 prize. to claim yr prize call our customer service representative on 08714712394 between 10am-7pm	1
+987	??_ and don???t worry we???ll have finished by march ??_ ish!	0
+988	sms services. for your inclusive text credits, pls goto www.comuk.net login= 3qxj9 unsubscribe with stop, no extra charge. help 08702840625.comuk. 220-cm2 9ae	1
+989	thx. all will be well in a few months	0
+990	a guy who gets used but is too dumb to realize it.	0
+991	:)	0
+992	i luv u soo much u don??t understand how special u r 2 me ring u 2morrow luv u xxx	0
+993	the  &lt;#&gt; g that i saw a few days ago, the guy wants sell wifi only for  &lt;#&gt;  and with 3g for  &lt;#&gt; . that's why i blanked him.	0
+994	hello my boytoy ... geeee i miss you already and i just woke up. i wish you were here in bed with me, cuddling me. i love you ...	0
+995	december only! had your mobile 11mths+? you are entitled to update to the latest colour camera mobile for free! call the mobile update co free on 08002986906	1
+996	lol. well quality aint bad at all so i aint complaining	0
+997	you aren't coming home between class, right? i need to work out and shower!	0
+998	i dont want to hear anything	0
+999	dad wanted to talk about the apartment so i got a late start, omw now	0
+1000	u coming 2 pick me?	0
+1001	hey they r not watching movie tonight so i'll prob b home early...	0
+1002	nvm it's ok...	0
+1003	4mths half price orange line rental & latest camera phones 4 free. had your phone 11mths ? call mobilesdirect free on 08000938767 to update now! or2stoptxt	1
+1004	congrats. that's great. i wanted to tell you not to tell me your score cos it might make me relax. but its motivating me so thanks for sharing	0
+1005	please don't text me anymore. i have nothing else to say.	0
+1006	yar lor actually we quite fast... cos da ge slow wat... haha...	0
+1007	yeah, i'll leave in a couple minutes &amp; let you know when i get to mu	0
+1008	lookatme!: thanks for your purchase of a video clip from lookatme!, you've been charged 35p. think you can do better? why not send a video in a mmsto 32323.	1
+1009	so u workin overtime nigpun?	0
+1010	have a good evening! ttyl	0
+1011	dad went out oredi...	0
+1012	wat would u like 4 ur birthday?	0
+1013	go fool dont cheat others ok	0
+1014	it so happens that there r 2waxsto do wat you want. she can come and ill get her medical insurance. and she'll be able to deliver and have basic care. i'm currently shopping for the right medical insurance for her. so just give me til friday morning. thats when i.ll see the major person that can guide me to the right insurance.	0
+1015	from next month get upto 50% more calls 4 ur standard network charge 2 activate call 9061100010 c wire3.net 1st4terms pobox84 m26 3uz cost ??1.50 min mobcudb more	1
+1016	never blame a day in ur life. good days give u happiness. bad days give u experience. both are essential in life! all are gods blessings! good morning.:	0
+1017	on ma way to school. can you pls send me ashley's number	0
+1018	he says hi and to get your ass back to south tampa (preferably at a kegger)	0
+1019	not a drop in the tank	0
+1020	hi dis is yijue i would be happy to work wif ?_ all for gek1510...	0
+1021	take care and sleep well.you need to learn to change in life.you only need to get convinced on that.i will wait but no more conversations between us.get convinced by that time.your family is over for you in many senses.respect them but not overemphasise.or u have no role in my life.	0
+1022	going to join tomorrow.	0
+1023	oops my phone died and i didn't even know. yeah i like it better.	0
+1024	k, jason says he's gonna be around so i'll be up there around  &lt;#&gt;	0
+1025	did you see that film:)	0
+1026	ok try to do week end course in coimbatore.	0
+1027	yun ah.the ubi one say if ?_ wan call by tomorrow.call 67441233 look for irene.ere only got bus8,22,65,61,66,382. ubi cres,ubi tech park.6ph for 1st 5wkg days.??n	0
+1028	that seems unnecessarily affectionate	0
+1029	i.ll always be there, even if its just in spirit. i.ll get a bb soon. just trying to be sure i need it.	0
+1030	havent still waitin as usual... ?? come back sch oredi?	0
+1031	lol ... i knew that .... i saw him in the dollar store	0
+1032	are you staying in town ?	0
+1033	why did i wake up on my own &gt;:(	0
+1034	is ur paper today in e morn or aft?	0
+1035	private! your 2003 account statement for shows 800 un-redeemed s. i. m. points. call 08718738002 identifier code: 48922 expires 21/11/04	1
+1036	ok lor... but buy wat?	0
+1037	well obviously not because all the people in my cool college life went home ;_;	0
+1038	wrong phone! this phone! i answer this one but assume the other is people i don't well	0
+1039	no prob. i will send to your email.	0
+1040	freemsg hi baby wow just got a new cam moby. wanna c a hot pic? or fancy a chat?im w8in 4utxt / rply chat to 82242 hlp 08712317606 msg150p 2rcv	1
+1041	naughty little thought: 'its better to flirt, flirt n flirt, rather than loving someone n gettin hurt, hurt n hurt...:-) gud nyt	0
+1042	u know we watchin at lido?	0
+1043	why do you ask princess?	0
+1044	solve d case : a man was found murdered on  &lt;decimal&gt; . &lt;#&gt;  afternoon. 1,his wife called police. 2,police questioned everyone. 3,wife: sir,i was sleeping, when the murder took place. 4.cook: i was cooking. 5.gardener: i was picking vegetables. 6.house-maid: i went 2 d post office. 7.children: we went 2 play. 8.neighbour: we went 2 a marriage. police arrested d murderer immediately. who's it? reply with reason, if u r brilliant.	0
+1045	500 new mobiles from 2004, must go! txt: nokia to no: 89545 & collect yours today!from only ??1 www.4-tc.biz 2optout 087187262701.50gbp/mtmsg18 txtauction	1
+1046	aight will do, thanks again for comin out	0
+1047	gudnite....tc...practice going on	0
+1048	sorry im stil fucked after last nite went tobed at 430 got up 4 work at 630	0
+1049	i'm there and i can see you, but you can't see me ? maybe you should reboot ym ? i seen the buzz	0
+1050	hows the champ just leaving glasgow!	0
+1051	u have a secret admirer who is looking 2 make contact with u-find out who they r*reveal who thinks ur so special-call on 09058094565	1
+1052	that's a shame! maybe cld meet for few hrs tomo?	0
+1053	i am not having her number sir	0
+1054	a ??400 xmas reward is waiting for you! our computer has randomly picked you from our loyal mobile customers to receive a ??400 reward. just call 09066380611	1
+1055	yalru lyfu astne chikku.. bt innu mundhe lyf ali halla ke bilo (marriage)program edhae, so lyf is nt yet ovr chikku..ali vargu lyfu meow meow:-d	0
+1056	evening * v good if somewhat event laden. will fill you in, don't you worry ??_ head * ok but throat * wrecked. see you at six then!	0
+1057	which is weird because i know i had it at one point	0
+1058	oh yeah clearly it's my fault	0
+1059	lol well don't do it without me. we could have a big sale together.	0
+1060	i keep ten rs in my shelf:) buy two egg.	0
+1061	* free* polyphonic ringtone text super to 87131 to get your free poly tone of the week now! 16 sn pobox202 nr31 7zs subscription 450pw	1
+1062	can you please send me my aunty's number	0
+1063	\life is nothing wen v get everything\". but \"life is everything wen v miss something \". real value of people wil be realized only in their absence.... gud mrng"	0
+1064	today i'm not workin but not free oso... gee... thgt u workin at ur fren's shop ?	0
+1065	glad to see your reply.	0
+1066	but i'll b going 2 sch on mon. my sis need 2 take smth.	0
+1067	no, i *didn't* mean to post it. i wrote it, and like so many other times i've ritten stuff to you, i let it sit there. it was what i was feeling at the time. i was angry. before i left, i hit send, then stop. it wasn't there. i checked on my phone when i got to my car. it wasn't there. you said you didn't sleep, you were bored. so why wouldn't that be the time to clean, fold laundry, etc.? at least make the bed?	0
+1068	i am at the gas station. go there.	0
+1069	urgent! your mobile number has been awarded with a ??2000 prize guaranteed. call 09058094454 from land line. claim 3030. valid 12hrs only	1
+1070	hey you told your name to gautham ah?	0
+1071	\pete can you please ring meive hardly gotany credit\""	0
+1072	u sick still can go shopping?	0
+1073	no 1 polyphonic tone 4 ur mob every week! just txt pt2 to 87575. 1st tone free ! so get txtin now and tell ur friends. 150p/tone. 16 reply hl 4info	1
+1074	ok...	0
+1075	good. do you think you could send me some pix? i would love to see your top and bottom...	0
+1076	ya had just now.onion roast.	0
+1077	how dare you stupid. i wont tell anything to you. hear after i wont talk to you:-.	0
+1078	yes watching footie but worried we're going to blow it - phil neville?	0
+1079	package all your programs well	0
+1080	for many things its an antibiotic and it can be used for chest abdomen and gynae infections even bone infections.	0
+1081	i'll pick you up at about 5.15pm to go to taunton if you still want to come.	0
+1082	g says you never answer your texts, confirm/deny	0
+1083	did u get that message	0
+1084	then why no one talking to me	0
+1085	or maybe my fat fingers just press all these buttons and it doesn't know what to do.	0
+1086	tell me again what your address is	0
+1087	well done and ! luv ya all	0
+1088	you should change your fb to jaykwon thuglyfe falconerf	0
+1089	goodmorning, today i am late for 2hrs. because of back pain.	0
+1090	does uncle timi help in clearing cars	0
+1091	i'm so in love with you. i'm excited each day i spend with you. you make me so happy.	0
+1092	now u sound like manky scouse boy steve,like! i is travelling on da bus home.wot has u inmind 4 recreation dis eve?	0
+1093	aiyar hard 2 type. u later free then tell me then i call n scold n tell u.	0
+1094	bored housewives! chat n date now! 0871750.77.11! bt-national rate 10p/min only from landlines!	1
+1095	kay... since we are out already	0
+1096	i would but i'm still cozy. and exhausted from last night.nobody went to school or work. everything is closed.	0
+1097	had your contract mobile 11 mnths? latest motorola, nokia etc. all free! double mins & text on orange tariffs. text yes for callback, no to remove from records	1
+1098	85233 free>ringtone!reply real	1
+1099	really... i tot ur paper ended long ago... but wat u copied jus now got use? u happy lar... i still haf 2 study :-(	0
+1100	its ok, if anybody asks abt me, u tel them..:-p	0
+1101	cmon babe, make me horny, *turn* me on! txt me your fantasy now babe -) im hot, sticky and need you now. all replies cost ??1.50. 2 cancel send stop	1
+1102	aiyo a bit pai seh ?_ noe... scared he dun rem who i am then die... hee... but he become better lookin oredi leh...	0
+1103	urgent! your mobile number *************** won a ??2000 bonus caller prize on 10/06/03! this is the 2nd attempt to reach you! call 09066368753 asap! box 97n7qp, 150ppm	1
+1104	urgent. important information for 02 user. today is your lucky day! 2 find out why , log onto http://www.urawinner.com there is a fantastic surprise awaiting you !	1
+1105	yup i'm elaborating on the safety aspects and some other issues..	0
+1106	xclusive@clubsaisai 2morow 28/5 soiree speciale zouk with nichols from paris.free roses 2 all ladies !!! info: 07946746291/07880867867	1
+1107	living is very simple.. loving is also simple.. laughing is too simple.. winning is tooo simple.. but, being 'simple' is very difficult...;-) :-)	0
+1108	ok no prob	0
+1109	ay wana meet on sat??_ wkg on sat?	0
+1110	you have an important customer service announcement. call freephone 0800 542 0825 now!	1
+1111	life spend with someone for a lifetime may be meaningless but a few moments spent with someone who really love you means more than life itself..	0
+1112	may b approve panalam...but it should have more posts..	0
+1113	it to 80488. your 500 free text messages are valid until 31 december 2005.	1
+1114	he neva grumble but i sad lor... hee... buy tmr lor aft lunch. but we still meetin 4 lunch tmr a not. neva hear fr them lei. ?? got a lot of work ar?	0
+1115	still in the area of the restaurant. ill try to come back soon	0
+1116	yar he quite clever but aft many guesses lor. he got ask me 2 bring but i thk darren not so willing 2 go. aiya they thk leona still not attach wat.	0
+1117	under the sea, there lays a rock. in the rock, there is an envelope. in the envelope, there is a paper. on the paper, there are 3 words... '	0
+1118	hi mom we might be back later than  &lt;#&gt;	0
+1119	no messages on her phone. i'm holding it now	0
+1120	i know! grumpy old people. my mom was like you better not be lying. then again i am always the one to play jokes...	0
+1121	congrats 2 mobile 3g videophones r yours. call 09063458130 now! videochat wid ur mates, play java games, dload polyph music, noline rentl. bx420. ip4. 5we. 150p	1
+1122	lol yeah at this point i guess not	0
+1123	i am going to sao mu today. will be done only at 12	0
+1124	jamster! to get your free wallpaper text heart to 88888 now! t&c apply. 16 only. need help? call 08701213186.	1
+1125	good afternoon, my love ... how goes your day ? how did you sleep ? i hope your well, my boytoy ... i think of you ...	0
+1126	she.s find. i sent you an offline message to know how anjola's now.	0
+1127	ok u can take me shopping when u get paid =d	0
+1128	lol i know! they're so dramatic. schools already closed for tomorrow. apparently we can't drive in the inch of snow were supposed to get.	0
+1129	ok i'm coming home now.	0
+1130	ok i've sent u da latest version of da project.	0
+1131	<forwarded from 88877>free entry into our ??250 weekly comp just send the word enter to 88877 now. 18 t&c www.textcomp.com	1
+1132	did u download the fring app?	0
+1133	join the uk's horniest dogging service and u can have sex 2nite!. just sign up and follow the instructions. txt entry to 69888 now! nyt.ec2a.3lp.msg@150p	1
+1134	let me know when you've got the money so carlos can make the call	0
+1135	dunno y u ask me.	0
+1136	whos this am in class:-)	0
+1137	he fucking chickened out. he messaged me he would be late and woould buzz me and then i didn't hear a word from him	0
+1138	its hard to believe things like this. all can say lie but think twice before saying anything to me.	0
+1139	no. 1 nokia tone 4 ur mob every week! just txt nok to 87021. 1st tone free ! so get txtin now and tell ur friends. 150p/tone. 16 reply hl 4info	1
+1140	i cant pick the phone right now. pls send a message	0
+1141	got hella gas money, want to go on a grand nature adventure with galileo in a little bit?	0
+1142	4mths half price orange line rental & latest camera phones 4 free. had your phone 11mths ? call mobilesdirect free on 08000938767 to update now! or2stoptxt	1
+1143	so that takes away some money worries	0
+1144	gumby's has a special where a  &lt;#&gt; \ cheese pizza is $2 so i know what we're doin tonight"	0
+1145	customer place i will call you	0
+1146	sir, you will receive the account no another 1hr time. sorry for the delay.	0
+1147	aight, i should be there by 8 at the latest, probably closer to 7. are jay and tyler down or should we just do two trips?	0
+1148	early bird! any purchases yet?	0
+1149	sorry, i'll call later	0
+1150	beautiful truth against gravity.. read carefully: \our heart feels light when someone is in it.. but it feels very heavy when someone leaves it..\" good night"	0
+1151	that way transport is less problematic than on sat night. by the way, if u want to ask  n  to join my bday, feel free. but need to know definite nos as booking on fri.	0
+1152	same as kallis dismissial in 2nd test:-).	0
+1153	yes i think so. i am in office but my lap is in room i think thats on for the last few days. i didnt shut that down	0
+1154	eh sorry leh... i din c ur msg. not sad already lar. me watching tv now. u still in office?	0
+1155	good evening sir, al salam wahleykkum.sharing a happy news.by the grace of god, i got an offer from tayseer,tissco and i joined.hope you are fine.inshah allah,meet you sometime.rakhesh,visitor from india.	0
+1156	* was a nice day and, impressively, i was sensible, went home early and now feel fine. or am i just boring?! when's yours, i can't remember.	0
+1157	yup next stop.	0
+1158	what makes you most happy?	0
+1159	okie... thanx...	0
+1160	adult 18 content your video will be with you shortly	1
+1161	really good:)dhanush rocks once again:)	0
+1162	tell them no need to investigate about me anywhere.	0
+1163	waiting in e car 4 my mum lor. u leh? reach home already?	0
+1164	no rushing. i'm not working. i'm in school so if we rush we go hungry.	0
+1165	aww that's the first time u said u missed me without asking if i missed u first. you do love me! :)	0
+1166	call me da, i am waiting for your call.	0
+1167	no..jst change tat only..	0
+1168	you are awarded a sipix digital camera! call 09061221061 from landline. delivery within 28days. t cs box177. m221bp. 2yr warranty. 150ppm. 16 . p p??3.99	1
+1169	hi! this is roger from cl. how are you?	0
+1170	no it will reach by 9 only. she telling she will be there. i dont know	0
+1171	are you in castor? you need to see something	0
+1172	all sounds good. fingers . makes it difficult to type	0
+1173	last chance! claim ur ??150 worth of discount vouchers today! text shop to 85023 now! savamob, offers mobile! t cs savamob pobox84, m263uz. ??3.00 sub. 16	1
+1174	\response\" is one of d powerful weapon 2 occupy a place in others 'heart'... so	0
+1175	alright we'll bring it to you, see you in like  &lt;#&gt;  mins	0
+1176	well i might not come then...	0
+1177	at the funeral home with audrey and dad	0
+1178	can not use foreign stamps in this country. good lecture .	0
+1179	yeah that'd pretty much be the best case scenario	0
+1180	i can't describe how lucky you are that i'm actually awake by noon	0
+1181	ok... let u noe when i leave my house.	0
+1182	its not the same here. still looking for a job. how much do ta's earn there.	0
+1183	noice. text me when you're here	0
+1184	looks like u wil b getting a headstart im leaving here bout 2.30ish but if u r desperate for my company i could head in earlier-we were goin to meet in rummer.	0
+1185	1) go to write msg 2) put on dictionary mode 3)cover the screen with hand, 4)press  &lt;#&gt; . 5)gently remove ur hand.. its interesting..:)	0
+1186	have you always been saying welp?	0
+1187	pls come quick cant bare this.	0
+1188	two teams waiting for some players	0
+1189	as per your request 'melle melle (oru minnaminunginte nurungu vettam)' has been set as your callertune for all callers. press *9 to copy your friends callertune	0
+1190	good morning plz call me sir	0
+1191	hi! you just spoke to maneesha v. we'd like to know if you were satisfied with the experience. reply toll free with yes or no.	0
+1192	as a registered optin subscriber ur draw 4 ??100 gift voucher will be entered on receipt of a correct ans to 80062 whats no1 in the bbc charts	1
+1193	ever green quote ever told by jerry in cartoon \a person who irritates u always is the one who loves u vry much but fails to express it...!..!! :-) :-) gud nyt"	0
+1194	:-( that's not v romantic!	0
+1195	i surely dont forgot to come:)i will always be in touch in with you:-)	0
+1196	no i'm not. i can't give you everything you want and need. you actually could do better for yourself on yor own--you've got more money than i do. i can't get work, i can't get a man, i can't pay the rent, i can't even fill my fucking gas tank. yes, i'm stressed and depressed. i didn't even call home for thanksgiving cuz i'll have to tell them i,m up to nothing.	0
+1197	sorry,in meeting i'll call later	0
+1198	when is school starting. where will you stay. what's the weather like. and the food. do you have a social support system like friends in the school. all these things are important.	0
+1199	dunno leh cant remember mayb lor. so wat time r we meeting tmr?	0
+1200	loan for any purpose ??500 - ??75,000. homeowners + tenants welcome. have you been previously refused? we can still help. call free 0800 1956669 or text back 'help'	1
+1201	i'm vivek:)i got call from your number.	0
+1202	omg i want to scream. i weighed myself and i lost more weight! woohoo!	0
+1203	my love ... i hope your not doing anything drastic. don't you dare sell your pc or your phone ...	0
+1204	huh... hyde park not in mel ah, opps, got confused... anyway, if tt's e best choice den we juz have to take it...	0
+1205	do you think i can move  &lt;#&gt;  in a week	0
+1206	some are lasting as much as 2 hours. you might get lucky.	0
+1207	oh ok i didnt know what you meant. yep i am baby jontin	0
+1208	how abt making some of the pics bigger?	0
+1209	mmmm.... i cant wait to lick it!	0
+1210	your opinion about me? 1. over 2. jada 3. kusruthi 4. lovable 5. silent 6. spl character 7. not matured 8. stylish 9. simple pls reply..	0
+1211	todays vodafone numbers ending with 0089(my last four digits) are selected to received a ??350 award. if your number matches please call 09063442151 to claim your ??350 award	1
+1212	yes:)from last week itself i'm taking live call.	0
+1213	08714712388 between 10am-7pm cost 10p	1
+1214	a bit of ur smile is my hppnss, a drop of ur tear is my sorrow, a part of ur heart is my life, a heart like mine wil care for u, forevr as my goodfriend	0
+1215	you won't believe it but it's true. it's incredible txts! reply g now to learn truly amazing things that will blow your mind. from o2fwd only 18p/txt	1
+1216	i thk 530 lor. but dunno can get tickets a not. wat u doing now?	0
+1217	in case you wake up wondering where i am, i forgot i have to take care of something for grandma today, should be done before the parade	0
+1218	no da:)he is stupid da..always sending like this:)don believe any of those message.pandy is a mental:)	0
+1219	check audrey's status right now	0
+1220	fyi i'm gonna call you sporadically starting at like  &lt;#&gt;  bc we are not not doin this shit	0
+1221	i want to be there so i can kiss you and feel you next to me	0
+1222	-pls stop bootydelious (32/f) is inviting you to be her friend. reply yes-434 or no-434 see her: www.sms.ac/u/bootydelious stop? send stop frnd to 62468	1
+1223	\not enufcredeit tocall.shall ileave uni at 6 +get a bus to yor house?\""	0
+1224	the house is on the water with a dock, a boat rolled up with a newscaster who dabbles in jazz flute behind the wheel	0
+1225	i can ask around but there's not a lot in terms of mids up here	0
+1226	nothing much, chillin at home. any super bowl plan?	0
+1227	i???m parked next to a mini!!!! when are you coming in today do you think?	0
+1228	it???s ??6 to get in, is that ok?	0
+1229	1.20 that call cost. which i guess isnt bad. miss ya, need ya, want ya, love ya	0
+1230	i know complain num only..bettr directly go to bsnl offc nd apply for it..	0
+1231	i am going to bed now prin	0
+1232	u can win ??100 of music gift vouchers every week starting now txt the word draw to 87066 tscs www.idew.com skillgame, 1winaweek, age16. 150ppermesssubscription	1
+1233	oh just getting even with u.... u?	0
+1234	3 pa but not selected.	0
+1235	how. its a little difficult but its a simple way to enter this place	0
+1236	honeybee said: *i'm d sweetest in d world* god laughed &amp; said: *wait,u havnt met d person reading this msg* moral: even god can crack jokes! gm+gn+ge+gn:)	0
+1237	alright we're hooked up, where you guys at	0
+1238	hi shanil,rakhesh here.thanks,i have exchanged the uncut diamond stuff.leaving back. excellent service by dino and prem.	0
+1239	?? come lt 25 n pass to me lar	0
+1240	no let me do the math. your not good at it.	0
+1241	i doubt you could handle 5 times per night in any case...	0
+1242	hi i won't b ard 4 christmas. but do enjoy n merry x'mas.	0
+1243	lol ... no just was busy	0
+1244	sorry, i'll call later ok bye	0
+1245	oh ok..	0
+1246	can't. i feel nauseous. i'm so pissed. i didn't eat any sweets all week cause today i was planning to pig out. i was dieting all week. and now i'm not hungry :/	0
+1247	we have sent jd for customer service cum accounts executive to ur mail id, for details contact us	0
+1248	congratulations u can claim 2 vip row a tickets 2 c blu in concert in november or blu gift guaranteed call 09061104276 to claim ts&cs www.smsco.net cost??3.75max	1
+1249	also remember the beads don't come off. ever.	0
+1250	get me out of this dump heap. my mom decided to come to lowes. boring.	0
+1251	i love to wine and dine my lady!	0
+1252	i dnt wnt to tlk wid u	0
+1253	i cant pick the phone right now. pls send a message	0
+1254	well, i meant as opposed to my drunken night of before	0
+1255	now get step 2 outta the way. congrats again.	0
+1256	awesome, think we can get an 8th at usf some time tonight?	0
+1257	well you told others you'd marry them...	0
+1258	this pain couldn't have come at a worse time.	0
+1259	dear :-/ why you mood off. i cant drive so i brother to drive	0
+1260	well the weather in cali's great. but its complexities are great. you need a car to move freely, its taxes are outrageous. but all in all its a great place. the sad part is i missing home.	0
+1261	okie	0
+1262	what r u cooking me for dinner?	0
+1263	happy new years melody!	0
+1264	yeah probably, i still gotta check out with leo	0
+1265	where @	0
+1266	my life means a lot to me, not because i love my life, but because i love the people in my life, the world calls them friends, i call them my world:-).. ge:-)..	0
+1267	shopping? eh ger i toking abt syd leh...haha	0
+1268	thanx 4 the time we??ve spent 2geva, its bin mint! ur my baby and all i want is u!xxxx	0
+1269	please protect yourself from e-threats. sib never asks for sensitive information like passwords,atm/sms pin thru email. never share your password with anybody.	0
+1270	did u find a sitter for kaitlyn? i was sick and slept all day yesterday.	0
+1271	pls speak with me. i wont ask anything other then you friendship.	0
+1272	i think the other two still need to get cash but we can def be ready by 9	0
+1273	not yet had..ya sapna aunty manege y'day hogidhe..chinnu full weak and swalpa black agidhane..	0
+1274	for my family happiness..	0
+1275	thank you for calling.forgot to say happy onam to you sirji.i am fine here and remembered you when i met an insurance person.meet you in qatar insha allah.rakhesh, ex tata aig who joined tissco,tayseer.	0
+1276	sending you greetings of joy and happiness. do have a gr8 evening	0
+1277	feel yourself that you are always happy.. slowly it becomes your habit &amp; finally it becomes part of your life.. follow it.. happy morning &amp; have a happy day:)	0
+1278	so why didnt you holla?	0
+1279	is it your yahoo boys that bring in the perf? or legal.	0
+1280	dorothy@kiefer.com (bank of granite issues strong-buy) explosive pick for our members *****up over 300% *********** nasdaq symbol cdgt that is a $5.00 per..	1
+1281	k.k..how is your sister kids?	0
+1282	roger that. we???re probably going to rem in about 20	0
+1283	yeah sure, give me a couple minutes to track down my wallet	0
+1284	what happen dear. why you silent. i am tensed	0
+1285	don't b floppy... b snappy & happy! only gay chat service with photo upload call 08718730666 (10p/min). 2 stop our texts call 08712460324	1
+1286	easy ah?sen got selected means its good..	0
+1287	wen ur lovable bcums angry wid u, dnt take it seriously.. coz being angry is d most childish n true way of showing deep affection, care n luv!.. kettoda manda... have nice day da.	0
+1288	fine i miss you very much.	0
+1289	latest news! police station toilet stolen, cops have nothing to go on!	1
+1290	very strange.  and  are watching the 2nd one now but i'm in bed. sweet dreams, miss u	0
+1291	do 1 thing! change that sentence into: \because i want 2 concentrate in my educational career im leaving here..\""	0
+1292	no problem. we will be spending a lot of quality time together...	0
+1293	double mins and txts 4 6months free bluetooth on orange. available on sony, nokia motorola phones. call mobileupd8 on 08000839402 or call2optout/n9dx	1
+1294	i'm back &amp; we're packing the car now, i'll let you know if there's room	0
+1295	well keep in mind i've only got enough gas for one more round trip barring a sudden influx of cash	0
+1296	cashbin.co.uk (get lots of cash this weekend!) www.cashbin.co.uk dear welcome to the weekend we have got our biggest and best ever cash give away!! these..	1
+1297	raji..pls do me a favour. pls convey my birthday wishes to nimya. pls. today is her birthday.	0
+1298	cool, want me to go to kappa or should i meet you outside mu	0
+1299	hey, a guy i know is breathing down my neck to get him some bud, anyway you'd be able to get a half track to usf tonight?	0
+1300	nothing. can...	0
+1301	you can donate ??2.50 to unicef's asian tsunami disaster support fund by texting donate to 864233. ??2.50 will be added to your next bill	1
+1302	dont let studying stress you out. l8r.	0
+1303	can you talk with me..	0
+1304	dear where you will be when i reach there	0
+1305	yar lor... how u noe? u used dat route too?	0
+1306	umma. did she say anything	0
+1307	what's a feathery bowa? is that something guys have that i don't know about?	0
+1308	neft transaction with reference number  &lt;#&gt;  for rs. &lt;decimal&gt;  has been credited to the beneficiary account on  &lt;#&gt;  at  &lt;time&gt; : &lt;#&gt;	0
+1309	mmmmmm ... i love you,so much, ahmad ... i can't wait for this year to begin as every second takes me closer to being at your side. happy new year, my love!!	0
+1310	from here after the performance award is calculated every two month.not for current one month period..	0
+1311	i don't want you to leave. but i'm barely doing what i can to stay sane. fighting with you constantly isn't helping.	0
+1312	no it's waiting in e car dat's bored wat. cos wait outside got nothing 2 do. at home can do my stuff or watch tv wat.	0
+1313	his frens go then he in lor. not alone wif my mum n sis lor.	0
+1314	am okay. will soon be over. all the best	0
+1315	free 1st week entry 2 textpod 4 a chance 2 win 40gb ipod or ??250 cash every wk. txt pod to 84128 ts&cs www.textpod.net custcare 08712405020.	1
+1316	he says hi and to get your ass back to south tampa (preferably at a kegger)	0
+1317	honey ? sweetheart ? darling ? sexy buns ? sugar plum ? loverboy ? i miss you, boytoy ... *smacks your ass* did you go to the gym too ?	0
+1318	po de :-):):-):-):-). no need job aha.	0
+1319	one of the joys in lifeis waking up each daywith thoughts that somewheresomeone cares enough tosend a warm morning greeting.. -	0
+1320	have * good weekend.	0
+1321	just re read it and i have no shame but tell me how he takes it and if he runs i will blame u 4 ever!! not really 4 ever just a long time	0
+1322	oh k k:)but he is not a big hitter.anyway good	0
+1323	hey girl. how r u? hope u r well me an del r bak! again long time no c! give me a call sum time from lucyxx	0
+1324	plz note: if anyone calling from a mobile co. &amp; asks u to type # &lt;#&gt;  or # &lt;#&gt; . do not do so. disconnect the call,coz it iz an attempt of 'terrorist' to make use of the sim card no. itz confirmd by nokia n motorola n has been verified by cnn ibn.	0
+1325	they released another italian one today and it has a cosign option	0
+1326	shop till u drop, is it you, either 10k, 5k, ??500 cash or ??100 travel voucher, call now, 09064011000. ntt po box cr01327bt fixedline cost 150ppm mobile vary	1
+1327	thanks for your message. i really appreciate your sacrifice. i'm not sure of the process of direct pay but will find out on my way back from the test tomorrow. i'm in class now. do have a wonderful day.	0
+1328	someone has contacted our dating service and entered your phone because they fancy you! to find out who it is call from a landline 09111032124 . pobox12n146tf150p	1
+1329	i guess that's why you re worried. you must know that there's a way the body repairs itself. and i'm quite sure you shouldn't worry. we'll take it slow. first the tests, they will guide when your ovulation is then just relax. nothing you've said is a reason to worry but i.ll keep on followin you up.	0
+1330	&lt;decimal&gt; m but its not a common car here so its better to buy from china or asia. or if i find it less expensive. i.ll holla	0
+1331	ok...	0
+1332	night has ended for another day, morning has come in a special way. may you smile like the sunny rays and leaves your worries at the blue blue bay. gud mrng	0
+1333	meet you in corporation st outside gap ??_ you can see how my mind is working!	0
+1334	even if he my friend he is a priest call him now	0
+1335	wife.how she knew the time of murder exactly	0
+1336	if you want to mapquest it or something look up \usf dogwood drive\"	0
+1337	oh ok.. wat's ur email?	0
+1338	want 2 get laid tonight? want real dogging locations sent direct 2 ur mob? join the uk's largest dogging network by txting moan to 69888nyt. ec2a. 31p.msg@150p	1
+1339	just arrived, see you in a couple days &lt;3	0
+1340	it's cool, we can last a little while. getting more any time soon?	0
+1341	company is very good.environment is terrific and food is really nice:)	0
+1342	i can't keep going through this. it was never my intention to run you out, but if you choose to do that rather than keep the room clean so *i* don't have to say no to visitors, then maybe that's the best choice. yes, i wanted you to be embarassed, so maybe you'd feel for once how i feel when i have a friend who wants to drop buy and i have to say no, as happened this morning. i've tried everything. i don't know what else to do.	0
+1343	hey babe, my friend had to cancel, still up for a visit ?	0
+1344	urgent ur awarded a complimentary trip to eurodisinc trav, aco&entry41 or ??1000. to claim txt dis to 87121 18+6*??1.50(morefrmmob. shracomorsglsuplt)10, ls1 3aj	1
+1345	the wine is flowing and i'm i have nevering..	0
+1346	hi darlin how was work did u get into trouble? ijust talked to your mum all morning! i had a really good time last night im goin out soon but call me if u can	0
+1347	i'm home, my love ... if your still awake ... *loving kiss*	0
+1348	i'm in chennai velachery:)	0
+1349	was doing my test earlier. i appreciate you. will call you tomorrow.	0
+1350	lol you won't feel bad when i use her money to take you out to a steak dinner =d	0
+1351	me too watching surya movie only. . .after 6 pm vijay movie pokkiri	0
+1352	i hope your pee burns tonite.	0
+1353	depends on where u going lor.	0
+1354	i like you peoples very much:) but am very shy pa.	0
+1355	i'm at work. please call	0
+1356	i actually did for the first time in a while. i went to bed not too long after i spoke with you. woke up at 7. how was your night?	0
+1357	just glad to be talking to you.	0
+1358	what year. and how many miles.	0
+1359	you have won a nokia 7250i. this is what you get when you win our free auction. to take part send nokia to 86021 now. hg/suite342/2lands row/w1jhl 16+	1
+1360	having lunch:)you are not in online?why?	0
+1361	he's just gonna worry for nothing. and he won't give you money its no use.	0
+1362	ok enjoy . r u there in home.	0
+1363	good morning princess! how are you?	0
+1364	i was at bugis juz now wat... but now i'm walking home oredi... ?? so late then reply... i oso saw a top dat i like but din buy... where r ?_ now?	0
+1365	nobody names their penis a girls name this story doesn't add up at all	0
+1366	lol i know! hey someone did a great inpersonation of flea on the forums. i love it!	0
+1367	yeah, probably earlier than that	0
+1368	arun can u transfr me d amt	0
+1369	ok k..sry i knw 2 siva..tats y i askd..	0
+1370	yup but not studying surfing lor. i'm in e lazy mode today.	0
+1371	nan sonathaya soladha. why boss?	0
+1372	i am sorry it hurt you.	0
+1373	oh fine, i'll be by tonight	0
+1374	dont gimme that lip caveboy	0
+1375	lol! u drunkard! just doing my hair at d moment. yeah still up 4 tonight. wats the plan?	0
+1376	really do hope the work doesnt get stressful. have a gr8 day.	0
+1377	short but cute : \ be a good person	0
+1378	yes.he have good crickiting mind	0
+1379	so how's the weather over there?	0
+1380	i'm on da bus going home...	0
+1381	not able to do anything.	0
+1382	i'm serious. you are in the money base	0
+1383	well imma definitely need to restock before thanksgiving, i'll let you know when i'm out	0
+1384	dunno da next show aft 6 is 850. toa payoh got 650.	0
+1385	oops i thk i dun haf enuff... i go check then tell ?_..	0
+1386	its on in engalnd! but telly has decided it won't let me watch it and mia and elliot were kissing! damn it!	0
+1387	ela kano.,il download, come wen ur free..	0
+1388	i dunno until when... lets go learn pilates...	0
+1389	hi chikku, send some nice msgs	0
+1390	when you and derek done with class?	0
+1391	sent me ur email id soon	0
+1392	yup	0
+1393	sorry, i'll call later	0
+1394	i sent your maga that money yesterday oh.	0
+1395	dude im no longer a pisces. im an aquarius now.	0
+1396	urgent ur awarded a complimentary trip to eurodisinc trav, aco&entry41 or ??1000. to claim txt dis to 87121 18+6*??1.50(morefrmmob. shracomorsglsuplt)10, ls1 3aj	1
+1397	sorry! u can not unsubscribe yet. the mob offer package has a min term of 54 weeks> pls resubmit request after expiry. reply themob help 4 more info	1
+1398	k, can that happen tonight?	0
+1399	i came hostel. i m going to sleep. plz call me up before class. hrishi.	0
+1400	as in i want custom officer discount oh.	0
+1401	realy sorry-i don't recognise this number and am now confused :) who r u please?!	0
+1402	fair enough, anything going on?	0
+1403	please call our customer service representative on freephone 0808 145 4742 between 9am-11pm as you have won a guaranteed ??1000 cash or ??5000 prize!	1
+1404	anything lar then ?_ not going home 4 dinner?	0
+1405	k:)i will give my kvb acc details:)	0
+1406	hey sexy buns! what of that day? no word from you this morning on ym ... :-( ... i think of you	0
+1407	house-maid is the murderer, coz the man was murdered on  &lt;#&gt; th january.. as public holiday all govt.instituitions are closed,including post office..understand?	0
+1408	hi , where are you? we're at  and they're not keen to go out i kind of am but feel i shouldn't so can we go out tomo, don't mind do you?	0
+1409	?? mean it's confirmed... i tot they juz say oni... ok then...	0
+1410	i didnt get ur full msg..sometext is missing, send it again	0
+1411	how are you, my love ? are you with your brother ? time to talk english with him ? *grins* say : hey muhommad, penny says hello from across the sea	0
+1412	k...k...yesterday i was in cbe .	0
+1413	great escape. i fancy the bridge but needs her lager. see you tomo	0
+1414	idea will soon get converted to live:)	0
+1415	i'll probably be by tomorrow (or even later tonight if something's going on)	0
+1416	mm not entirely sure i understood that text but hey. ho. which weekend?	0
+1417	what's ur pin?	0
+1418	i am late. i will be there at	0
+1419	sppok up ur mob with a halloween collection of nokia logo&pic message plus a free eerie tone, txt card spook to 8007	1
+1420	o shore are you takin the bus	0
+1421	we're done...	0
+1422	as usual u can call me ard 10 smth.	0
+1423	thanks, i'll keep that in mind	0
+1424	how dare you change my ring	0
+1425	i'll be in sch fr 4-6... i dun haf da book in sch... it's at home...	0
+1426	win a year supply of cds 4 a store of ur choice worth ??500 & enter our ??100 weekly draw txt music to 87066 ts&cs www.ldew.com.subs16+1win150ppmx3	1
+1427	and maybe some pressies	0
+1428	nope i'll come online now..	0
+1429	give me a sec to think think about it	0
+1430	i shall book chez jules for half eight, if that's ok with you?	0
+1431	msg me when rajini comes.	0
+1432	ur awarded a city break and could win a ??200 summer shopping spree every wk. txt store to 88039 . skilgme. tscs087147403231winawk!age16 ??1.50perwksub	1
+1433	nt only for driving even for many reasons she is called bbd..thts it chikku, then hw abt dvg cold..heard tht vinobanagar violence hw is the condition..and hw ru ? any problem?	0
+1434	sorry de i went to shop.	0
+1435	great! i hope you like your man well endowed. i am  &lt;#&gt;  inches...	0
+1436	talk with yourself atleast once in a day...!!! otherwise you will miss your best friend in this world...!!! -shakespeare- shesil  &lt;#&gt;	0
+1437	as per your request 'maangalyam (alaipayuthe)' has been set as your callertune for all callers. press *9 to copy your friends callertune	0
+1438	you made my day. do have a great day too.	0
+1439	best line said in love: . \i will wait till the day i can forget u or the day u realize that u cannot forget me.\"... gn"	0
+1440	update_now - xmas offer! latest motorola, sonyericsson & nokia & free bluetooth! double mins & 1000 txt on orange. call mobileupd8 on 08000839402 or call2optout/f4q=	1
+1441	i'll text carlos and let you know, hang on	0
+1442	st andre, virgil's cream	0
+1443	awesome, be there in a minute	0
+1444	mm have some kanji dont eat anything heavy ok	0
+1445	congratulations ur awarded 500 of cd vouchers or 125gift guaranteed & free entry 2 100 wkly draw txt music to 87066	1
+1446	no shit, but i wasn't that surprised, so i went and spent the evening with that french guy i met in town here and we fooled around a bit but i didn't let him fuck me	0
+1447	if you mean the website. yes.	0
+1448	aldrine, rakhesh ex rtm here.pls call.urgent.	0
+1449	hi, wkend ok but journey terrible. wk not good as have huge back log of marking to do	0
+1450	new theory: argument wins d situation, but loses the person. so dont argue with ur friends just.. . . . kick them &amp; say, i'm always correct.!	0
+1451	usf i guess, might as well take 1 car	0
+1452	r u here yet? i'm wearing blue shirt n black pants.	0
+1453	no. i dont want to hear anything	0
+1454	ok c ?_ then.	0
+1455	why is that, princess? i bet the brothas are all chasing you!	0
+1456	omw back to tampa from west palm, you hear what happened?	0
+1457	ho ho - big belly laugh! see ya tomo	0
+1458	please dont say like that. hi hi hi	0
+1459	sorry * was at the grocers.	0
+1460	it's really getting me down just hanging around.	0
+1461	well done! your 4* costa del sol holiday or ??5000 await collection. call 09050090044 now toclaim. sae, tcs, pobox334, stockport, sk38xh, cost??1.50/pm, max10mins	1
+1462	urgent! your mobile no 07808726822 was awarded a ??2,000 bonus caller prize on 02/09/03! this is our 2nd attempt to contact you! call 0871-872-9758 box95qu	1
+1463	i went to ur hon lab but no one is there.	0
+1464	dear voucher holder, to claim this weeks offer, at you pc please go to http://www.e-tlp.co.uk/expressoffer ts&cs apply. to stop texts, txt stop to 80062	1
+1465	do u hav any frnd by name ashwini in ur college?	0
+1466	x course it 2yrs. just so her messages on messenger lik you r sending me	0
+1467	i need you to be in my strong arms...	0
+1468	how are you. just checking up on you	0
+1469	nice line said by a broken heart- plz don't cum 1 more times infront of me... other wise once again i ll trust u... good 9t:)	0
+1470	think i could stop by in like an hour or so? my roommate's looking to stock up for a trip	0
+1471	i think asking for a gym is the excuse for lazy people. i jog.	0
+1472	i said its okay. sorry	0
+1473	friendship poem: dear o dear u r not near but i can hear dont get fear live with cheer no more tear u r always my dear. gud ni8	0
+1474	that's y we haf to combine n c how lor...	0
+1475	tick, tick, tick .... where are you ? i could die of loneliness you know ! *pouts* *stomps feet* i need you ...	0
+1476	in sch but neva mind u eat 1st lor..	0
+1477	gokila is talking with you aha:)	0
+1478	spook up your mob with a halloween collection of a logo & pic message plus a free eerie tone, txt card spook to 8007 zed 08701417012150p per logo/pic	1
+1479	twinks, bears, scallies, skins and jocks are calling now. don't miss the weekend's fun. call 08712466669 at 10p/min. 2 stop texts call 08712460324(nat rate)	1
+1480	aight, i'll hit you up when i get some cash	0
+1481	hi. i'm always online on yahoo and would like to chat with you someday	0
+1482	hi 07734396839 ibh customer loyalty offer: the new nokia6600 mobile from only ??10 at txtauction!txt word:start to no:81151 & get yours now!4t&	1
+1483	water logging in desert. geoenvironmental implications.	0
+1484	are you unique enough? find out from 30th august. www.areyouunique.co.uk	1
+1485	have you ever had one foot before?	0
+1486	now press conference da:)	0
+1487	me sef dey laugh you. meanwhile how's my darling anjie!	0
+1488	when you guys planning on coming over?	0
+1489	india have to take lead:)	0
+1490	probably not, i'm almost out of gas and i get some cash tomorrow	0
+1491	huh y lei...	0
+1492	wherre's my boytoy ? :-(	0
+1493	my uncles in atlanta. wish you guys a great semester.	0
+1494	mmm thats better now i got a roast down me! i??d b better if i had a few drinks down me 2! good indian?	0
+1495	i tot it's my group mate... lucky i havent reply... wat time do ?_ need to leave...	0
+1496	do you want a new video phone? 600 anytime any network mins 400 inclusive video calls and downloads 5 per week free deltomorrow call 08002888812 or reply now	1
+1497	i am in hospital da. . i will return home in evening	0
+1498	\im on gloucesterroad what are uup to later?\""	0
+1499	well that must be a pain to catch	0
+1500	dear,regret i cudnt pick call.drove down frm ctla now at cochin home.left mobile in car..ente style ishtamayoo?happy bakrid!	0
+1501	someone u know has asked our dating service 2 contact you! cant guess who? call 09058095107 now all will be revealed. pobox 7, s3xy 150p	1
+1502	funny fact nobody teaches volcanoes 2 erupt, tsunamis 2 arise, hurricanes 2 sway aroundn no 1 teaches hw 2 choose a wife natural disasters just happens	0
+1503	will you be here for food	0
+1504	aiyo... her lesson so early... i'm still sleepin, haha... okie, u go home liao den confirm w me lor...	0
+1505	hope you are feeling great. pls fill me in. abiola	0
+1506	k:)k..its good:)when are you going?	0
+1507	i noe la... u wana pei bf oso rite... k lor, other days den...	0
+1508	had your contract mobile 11 mnths? latest motorola, nokia etc. all free! double mins & text on orange tariffs. text yes for callback, no to remove from records.	1
+1509	alright. thanks for the advice. enjoy your night out. i'ma try to get some sleep...	0
+1510	an excellent thought by a misundrstud frnd: i knw u hate me bt the day wen u'll knw the truth u'll hate urself:-( gn:-)	0
+1511	i like to talk pa but am not able to. i dont know y.	0
+1512	show ur colours! euro 2004 2-4-1 offer! get an england flag & 3lions tone on ur phone! click on the following service message for info!	1
+1513	how long before you get reply, just defer admission til next semester	0
+1514	we tried to call you re your reply to our sms for a video mobile 750 mins unlimited text + free camcorder reply of call 08000930705 now	1
+1515	but your brother transfered only  &lt;#&gt;  +  &lt;#&gt; . pa.	0
+1516	havent shopping now lor i juz arrive only	0
+1517	yo, the game almost over? want to go to walmart soon	0
+1518	just seeing your missed call my dear brother. do have a gr8 day.	0
+1519	friendship is not a game to play, it is not a word to say, it doesn\'t start on march and ends on may, it is tomorrow, yesterday, today and e	0
+1520	alex says he's not ok with you not being ok with it	0
+1521	free tones hope you enjoyed your new content. text stop to 61610 to unsubscribe. help:08712400602450p provided by tones2you.co.uk	1
+1522	hmm. shall i bring a bottle of wine to keep us amused? just joking! i'll still bring a bottle. red or white? see you tomorrow	0
+1523	and is there a way you can send shade's stuff to her. and she has been wonderful too.	0
+1524	under the sea, there lays a rock. in the rock, there is an envelope. in the envelope, there is a paper. on the paper, there are 3 words... '	0
+1525	new car and house for my parents.:)i have only new job in hand:)	0
+1526	gud ni8 dear..slp well..take care..swt dreams..muah..	0
+1527	thanx4 today cer it was nice 2 catch up but we ave 2 find more time more often oh well take care c u soon.c	0
+1528	then ?_ wait 4 me at bus stop aft ur lect lar. if i dun c ?_ then i go get my car then come back n pick ?_.	0
+1529	am on a train back from northampton so i'm afraid not! i'm staying skyving off today ho ho! will be around wednesday though. do you fancy the comedy club this week by the way?	0
+1530	do 1 thing! change that sentence into: \because i want 2 concentrate in my educational career im leaving here..\""	0
+1531	no. 1 nokia tone 4 ur mob every week! just txt nok to 87021. 1st tone free ! so get txtin now and tell ur friends. 150p/tone. 16 reply hl 4info	1
+1532	howz pain.it will come down today.do as i said ystrday.ice and medicine.	0
+1533	this is one of the days you have a billion classes, right?	0
+1534	??_ we r stayin here an extra week, back next wed. how did we do in the rugby this weekend? hi to and and , c u soon	0
+1535	r u in this continent?	0
+1536	it should take about  &lt;#&gt;  min	0
+1537	i cant pick the phone right now. pls send a message	0
+1538	lol for real. she told my dad i have cancer	0
+1539	i see. when we finish we have loads of loans to pay	0
+1540	jason says it's cool if we pick some up from his place in like an hour	0
+1541	yay can't wait to party together!	0
+1542	\hi missed your call and my mumhas beendropping red wine all over theplace! what is your adress?\""	0
+1543	i call you later, don't have network. if urgnt, sms me.	0
+1544	(no promises on when though, haven't even gotten dinner yet)	0
+1545	here is your discount code rp176781. to stop further messages reply stop. www.regalportfolio.co.uk. customer services 08717205546	1
+1546	came to look at the flat, seems ok, in his 50s? * is away alot wiv work. got woman coming at 6.30 too.	0
+1547	yup ?_ not comin :-(	0
+1548	yes! i am a one woman man! please tell me your likes and dislikes in bed...	0
+1549	hi hope u get this txt~journey hasnt been gd,now about 50 mins late i think.	0
+1550	come to my home for one last time i wont do anything. trust me.	0
+1551	so what do you guys do.	0
+1552	gal n boy walking in d park. gal-can i hold ur hand? boy-y? do u think i would run away? gal-no, jst wana c how it feels walking in heaven with an prince..gn:-)	0
+1553	as i entered my cabin my pa said, '' happy b'day boss !!''. i felt special. she askd me 4 lunch. after lunch she invited me to her apartment. we went there.	0
+1554	your gonna have to pick up a $1 burger for yourself on your way home. i can't even move. pain is killing me.	0
+1555	i wnt to buy a bmw car urgently..its vry urgent.but hv a shortage of  &lt;#&gt; lacs.there is no source to arng dis amt. &lt;#&gt; lacs..thats my prob	0
+1556	haven't left yet so probably gonna be here til dinner	0
+1557	jokin only lar... :-) depends on which phone my father can get lor...	0
+1558	i wont touch you with out your permission.	0
+1559	ujhhhhhhh computer shipped out with address to sandiago and parantella lane. wtf. poop.	0
+1560	dude got a haircut. now its breezy up there	0
+1561	hello. we need some posh birds and chaps to user trial prods for champneys. can i put you down? i need your address and dob asap. ta r	1
+1562	i'm done. c ?_ there.	0
+1563	hi msg me:)i'm in office..	0
+1564	nothing but we jus tot u would ask cos u ba gua... but we went mt faber yest... yest jus went out already mah so today not going out... jus call lor...	0
+1565	sorry, i guess whenever i can get a hold of my connections, maybe an hour or two? i'll text you	0
+1566	yeah i imagine he would be really gentle. unlike the other docs who treat their patients like turkeys.	0
+1567	i lost 4 pounds since my doc visit last week woot woot! now i'm gonna celebrate by stuffing my face!	0
+1568	i got your back! do you have any dislikes in bed?	0
+1569	can ?_ send me a copy of da report?	0
+1570	love you aathi..love u lot..	0
+1571	siva is in hostel aha:-.	0
+1572	can you plz tell me the ans. bslvyl sent via fullonsms.com	0
+1573	ya very nice. . .be ready on thursday	0
+1574	reckon need to be in town by eightish to walk from * carpark.	0
+1575	love that holiday monday feeling even if i have to go to the dentists in an hour	0
+1576	call me when u finish then i come n pick u.	0
+1577	me too baby! i promise to treat you well! i bet you will take good care of me...	0
+1578	boo. how's things? i'm back at home and a little bored already :-(	0
+1579	hot live fantasies call now 08707500020 just 20p per min ntt ltd, po box 1327 croydon cr9 5wb 0870 is a national rate call	1
+1580	send this to ur friends and receive something about ur voice..... how is my speaking expression? 1.childish 2.naughty 3.sentiment 4.rowdy 5.ful of attitude 6.romantic 7.shy 8.attractive 9.funny  &lt;#&gt; .irritating  &lt;#&gt; .lovable. reply me..	0
+1581	are we doing the norm tomorrow? i finish just a 4.15 cos of st tests. need to sort library stuff out at some point tomo - got letter from today - access til end march so i better get move on!	0
+1582	if u sending her home first it's ok lor. i'm not ready yet.	0
+1583	a few people are at the game, i'm at the mall with iouri and kaila	0
+1584	last chance! claim ur ??150 worth of discount vouchers today! text shop to 85023 now! savamob, offers mobile! t cs savamob pobox84, m263uz. ??3.00 sub. 16	1
+1585	nt joking seriously i told	0
+1586	like i made him throw up when we were smoking in our friend's car one time, it was awesome	0
+1587	hey... thk we juz go accordin to wat we discussed yest lor, except no kb on sun... cos there's nt much lesson to go if we attend kb on sat...	0
+1588	squeeeeeze!! this is christmas hug.. if u lik my frndshp den hug me back.. if u get 3 u r cute:) 6 u r luvd:* 9 u r so lucky;) none? people hate u:	0
+1589	but my family not responding for anything. now am in room not went to home for diwali but no one called me and why not coming. it makes me feel like died.	0
+1590	thursday night? yeah, sure thing, we'll work it out then	0
+1591	thanks for loving me so. you rock	0
+1592	thts god's gift for birds as humans hav some natural gift frm god..	0
+1593	hello! how r u? im bored. inever thought id get bored with the tv but i am. tell me something exciting has happened there? anything! =/	0
+1594	call 09095350301 and send our girls into erotic ecstacy. just 60p/min. to stop texts call 08712460324 (nat rate)	1
+1595	no..he joined today itself.	0
+1596	can you say what happen	0
+1597	i'm ok. will do my part tomorrow	0
+1598	yeah go on then, bored and depressed sittin waitin for phone to ring... hope the wind drops though, scary	0
+1599	tell them the drug dealer's getting impatient	0
+1600	yes. that will be fine. love you. be safe.	0
+1601	yes princess! i want to catch you with my big strong hands...	0
+1602	not heard from u4 a while. call me now am here all night with just my knickers on. make me beg for it like u did last time 01223585236 xx luv nikiyu4.net	1
+1603	how come u got nothing to do?	0
+1604	be sure to check your yahoo email. we sent photos yesterday	0
+1605	what time you thinkin of goin?	0
+1606	urgent! we are trying to contact u. todays draw shows that you have won a ??2000 prize guaranteed. call 09058094507 from land line. claim 3030. valid 12hrs only	1
+1607	orh i tot u say she now still dun believe.	0
+1608	hey sweet, i was wondering when you had a moment if you might come to me ? i want to send a file to someone but it won't go over yahoo for them because their connection sucks, remember when you set up that page for me to go to and download the format disc ? could you tell me how to do that ? or do you know some other way to download big files ? because they can download stuff directly from the internet. any help would be great, my prey ... *teasing kiss*	0
+1609	erutupalam thandiyachu	0
+1610	freemsg:feelin kinda lnly hope u like 2 keep me company! jst got a cam moby wanna c my pic?txt or reply date to 82242 msg150p 2rcv hlp 08712317606 stop to 82242	1
+1611	2 laptop... i noe infra but too slow lar... i wan fast one	0
+1612	so the sun is anti sleep medicine.	0
+1613	you are chosen to receive a ??350 award! pls call claim number 09066364311 to collect your award which you are selected to receive as a valued mobile customer.	1
+1614	oh baby of the house. how come you dont have any new pictures on facebook	0
+1615	hi harish's rent has been transfred to ur acnt.	0
+1616	lolnice. i went from a fish to ..water.?	0
+1617	i've been searching for the right words to thank you for this breather. i promise i wont take your help for granted and will fulfil my promise. you have been wonderful and a blessing at all times.	0
+1618	thanks a lot for your wishes on my birthday. thanks you for making my birthday truly memorable.	0
+1619	can ?_ all decide faster cos my sis going home liao..	0
+1620	hello.how u doing?what u been up 2?when will u b moving out of the flat, cos i will need to arrange to pick up the lamp, etc. take care. hello caroline!	0
+1621	wat makes some people dearer is not just de happiness dat u feel when u meet them but de pain u feel when u miss dem!!!	0
+1622	when are you guys leaving?	0
+1623	oic... then better quickly go bathe n settle down...	0
+1624	send a logo 2 ur lover - 2 names joined by a heart. txt love name1 name2 mobno eg love adam eve 07123456789 to 87077 yahoo! pobox36504w45wq txtno 4 no ads 150p.	1
+1625	yo do you know anyone  &lt;#&gt;  or otherwise able to buy liquor? our guy flaked and right now if we don't get a hold of somebody its just 4 loko all night	0
+1626	leave it. u will always be ignorant.	0
+1627	just making dinner, you ?	0
+1628	thats cool. where should i cum? on you or in you? :)	0
+1629	&lt;#&gt;  in mca. but not conform.	0
+1630	i don't think i can get away for a trek that long with family in town, sorry	0
+1631	congrats! nokia 3650 video camera phone is your call 09066382422 calls cost 150ppm ave call 3mins vary from mobiles 16+ close 300603 post bcm4284 ldn wc1n3xx	1
+1632	i just cooked a rather nice salmon a la you	0
+1633	ugh. gotta drive back to sd from la. my butt is sore.	0
+1634	it'll be tough, but i'll do what i have to	0
+1635	mobile club: choose any of the top quality items for your mobile. 7cfca1a	1
+1636	your 2004 account for 07xxxxxxxxx shows 786 unredeemed points. to claim call 08719181259 identifier code: xxxxx expires 26.03.05	1
+1637	gud ni8 dear..slp well..take care..swt dreams..muah..	0
+1638	i (career tel) have added u as a contact on indyarocks.com to send free sms. to remove from phonebook - sms no to  &lt;#&gt;	0
+1639	10 min later k...	0
+1640	our records indicate u maybe entitled to 5000 pounds in compensation for the accident you had. to claim 4 free reply with claim to this msg. 2 stop txt stop	1
+1641	where are you?when wil you reach here?	0
+1642	they said if its gonna snow, it will start around 8 or 9 pm tonite! they are predicting an inch of accumulation.	0
+1643	urgent! we are trying to contact u. todays draw shows that you have won a ??800 prize guaranteed. call 09050001808 from land line. claim m95. valid12hrs only	1
+1644	okay. no no, just shining on. that was meant to be signing, but that sounds better.	0
+1645	yay! you better not have told that to 5 other girls either.	0
+1646	refused a loan? secured or unsecured? can't get credit? call free now 0800 195 6669 or text back 'help' & we will!	1
+1647	happy or sad , one thing about past is- \its no more\" good morning :-):-)."	0
+1648	hmm well, night night	0
+1649	well then you have a great weekend!	0
+1650	does daddy have a bb now.	0
+1651	camera - you are awarded a sipix digital camera! call 09061221066 fromm landline. delivery within 28 days.	1
+1652	so how many days since then?	0
+1653	i wud never mind if u dont miss me or if u dont need me.. but u wil really hurt me wen u need me &amp; u dont tell me......... take care:-)	0
+1654	you intrepid duo you! have a great time and see you both soon.	0
+1655	come to medical college at 7pm ......forward it da	0
+1656	gr8 poly tones 4 all mobs direct 2u rply with poly title to 8007 eg poly breathe1 titles: crazyin, sleepingwith, finest, ymca :getzed.co.uk pobox365o4w45wq 300p	1
+1657	once a fishrman woke early in d mrng. it was very dark. he waited a while &amp; found a sack ful of stones. he strtd throwin thm in2 d sea 2 pass time. atlast he had jus 1stone, sun rose up &amp; he found out tht those r nt stones, those were diamonds. moral:\dont wake up early in d mrng'' good night"	0
+1658	hi da:)how is the todays class?	0
+1659	i though we shd go out n have some fun so bar in town or something ??? sound ok?	0
+1660	free msg. sorry, a service you ordered from 81303 could not be delivered as you do not have sufficient credit. please top up to receive the service.	1
+1661	k i'll call you when i'm close	0
+1662	ya ok, then had dinner?	0
+1663	will you like to be spoiled? :)	0
+1664	good night. am going to sleep.	0
+1665	?? predict wat time ?_'ll finish buying?	0
+1666	take something for pain. if it moves however to any side in the next 6hrs see a doctor.	0
+1667	pls call me da. what happen.	0
+1668	yes i know the cheesy songs from frosty the snowman :)	0
+1669	are you still playing with gautham?	0
+1670	i can make it up there, squeezed  &lt;#&gt;  bucks out of my dad	0
+1671	i get out of class in bsn in like  &lt;#&gt;  minutes, you know where advising is?	0
+1672	think + da. you wil do.	0
+1673	thats cool princess! i will cover your face in hot sticky cum :)	0
+1674	yar lor he wan 2 go c horse racing today mah, so eat earlier lor. i ate chicken rice. u?	0
+1675	hello, my love! how goes that day ? i wish your well and fine babe and hope that you find some job prospects. i miss you, boytoy ... *a teasing kiss*	0
+1676	going on nothing great.bye	0
+1677	yeah work is fine, started last week, all the same stuff as before, dull but easy and guys are fun!	0
+1678	u 447801259231 have a secret admirer who is looking 2 make contact with u-find out who they r*reveal who thinks ur so special-call on 09058094597	1
+1679	now? i'm going out 4 dinner soon..	0
+1680	i'll reach in ard 20 mins ok...	0
+1681	i accidentally deleted the message. resend please.	0
+1682	see you there!	0
+1683	probably, want to pick up more?	0
+1684	there is no sense in my foot and penis.	0
+1685	juz go google n search 4 qet...	0
+1686	my friends use to call the same.	0
+1687	i tagged my friends that you seemed to count as your friends.	0
+1688	aight what time you want me to come up?	0
+1689	i liked the new mobile	0
+1690	rightio. 11.48 it is then. well arent we all up bright and early this morning.	0
+1691	hiya hows it going in sunny africa? hope u r avin a good time. give that big old silver back a big kiss from me.	0
+1692	great! how is the office today?	0
+1693	1. tension face 2. smiling face 3. waste face 4. innocent face 5.terror face 6.cruel face 7.romantic face 8.lovable face 9.decent face  &lt;#&gt; .joker face.	0
+1694	can you do a mag meeting this avo at some point?	0
+1695	man this bus is so so so slow. i think you're gonna get there before me	0
+1696	ard 515 like dat. y?	0
+1697	can you pls send me that company name. in saibaba colany	0
+1698	how much i gave to you. morning.	0
+1699	nah man, my car is meant to be crammed full of people	0
+1700	get your garden ready for summer with a free selection of summer bulbs and seeds worth ??33:50 only with the scotsman this saturday. to stop go2 notxt.co.uk	1
+1701	i need details about that online job.	0
+1702	yes:)sura in sun tv.:)lol.	0
+1703	is that seriously how you spell his name?	0
+1704	have a great trip to india. and bring the light to everyone not just with the project but with everyone that is lucky to see you smile. bye. abiola	0
+1705	many times we lose our best ones bcoz we are	0
+1706	get a brand new mobile phone by being an agent of the mob! plus loads more goodies! for more info just text mat to 87021.	1
+1707	if you don't, your prize will go to another customer. t&c at www.t-c.biz 18+ 150p/min polo ltd suite 373 london w1j 6hl please call back if busy	1
+1708	you have to pls make a note of all she.s exposed to. also find out from her school if anyone else was vomiting. is there a dog or cat in the house? let me know later.	0
+1709	u really pig leh sleep so much. my dad wake me up at 10 smth 2 eat lunch today.	0
+1710	yeah we wouldn't leave for an hour at least, how's 4 sound?	0
+1711	, ,  and  picking them up from various points | going 2 yeovil | and they will do the motor project 4 3 hours | and then u take them home. || 12 2 5.30 max. || very easy	0
+1712	a pure hearted person can have a wonderful smile that makes even his/her enemies to feel guilty for being an enemy.. so catch the world with your smile..:) goodmorning &amp; have a smiley sunday..:)	0
+1713	ok not a problem will get them a taxi. c ing  tomorrow and tuesday. on tuesday think we r all going to the cinema.	0
+1714	hi darlin im missin u hope you are having a good time. when are u back and what time if u can give me a call at home. jess xx	0
+1715	private! your 2004 account statement for 07742676969 shows 786 unredeemed bonus points. to claim call 08719180248 identifier code: 45239 expires	1
+1716	waaaat?? lololo ok next time then!	0
+1717	beautiful truth : expression of the face could be seen by everyone... but the depression of heart could be understood only by the loved ones.. gud ni8;-)	0
+1718	then why you not responding	0
+1719	?? v ma fan...	0
+1720	i had it already..sabarish asked me to go..	0
+1721	tonight? yeah, i'd be down for that	0
+1722	thanks for looking out for me. i really appreciate.	0
+1723	unless it's a situation where you go gurl would be more appropriate	0
+1724	free>ringtone! reply real or poly eg real1 1. pushbutton 2. dontcha 3. babygoodbye 4. golddigger 5. webeburnin 1st tone free and 6 more when u join for ??3/wk	1
+1725	you have won a guaranteed ??1000 cash or a ??2000 prize. to claim yr prize call our customer service representative on 08714712412 between 10am-7pm cost 10p	1
+1726	the sign of maturity is not when we start saying big things.. but actually it is, when we start understanding small things... *have a nice evening* bslvyl	0
+1727	probably money worries. things are coming due and i have several outstanding invoices for work i did two and three months ago.	0
+1728	do you know when the result.	0
+1729	see, i knew giving you a break a few times woul lead to you always wanting to miss curfew. i was gonna gibe you 'til one, but a midnight movie is not gonna get out til after 2. you need to come home. you need to getsleep and, if anything, you need to b studdying ear training.	0
+1730	i will reach office around  &lt;decimal&gt; . &amp; my mobile have problem. you cann't get my voice. so call you asa i'll free	0
+1731	still otside le..u come 2morrow maga..	0
+1732	?? still attending da talks?	0
+1733	you have got tallent but you are wasting.	0
+1734	ok lar i double check wif da hair dresser already he said wun cut v short. he said will cut until i look nice.	0
+1735	cthen i thk shd b enuff.. still got conclusion n contents pg n references.. i'll b doing da contents pg n cover pg..	0
+1736	in other news after hassling me to get him weed for a week andres has no money. haughaighgtujhyguj	0
+1737	asked 3mobile if 0870 chatlines inclu in free mins. india cust servs sed yes. l8er got mega bill. 3 dont giv a shit. bailiff due in days. i o ??250 3 want ??800	1
+1738	yes. i come to nyc for audiitions and am trying to relocate.	0
+1739	cheers for the card ... is it that time of year already?	0
+1740	i can't, i don't have her number!	0
+1741	\hey kate	0
+1742	said kiss, kiss, i can't do the sound effects! he is a gorgeous man isn't he! kind of person who needs a smile to brighten his day!	0
+1743	have a safe trip to nigeria. wish you happiness and very soon company to share moments with	0
+1744	dai what this da.. can i send my resume to this id.	0
+1745	u should have made an appointment	0
+1746	\for the most sparkling shopping breaks from 45 per person; call 0121 2025050 or visit www.shortbreaks.org.uk\""	1
+1747	mm umma ask vava also to come tell him can play later together	0
+1748	oh that was a forwarded message. i thought you send that to me	0
+1749	dont know you bring some food	0
+1750	r u sure they'll understand that! wine * good idea just had a slurp!	0
+1751	anything lor but toa payoh got place 2 walk meh...	0
+1752	bring home some wendy =d	0
+1753	y bishan lei... i tot ?_ say lavender?	0
+1754	ok i shall talk to him	0
+1755	k. i will sent it again	0
+1756	new textbuddy chat 2 horny guys in ur area 4 just 25p free 2 receive search postcode or at gaytextbuddy.com. txt one name to 89693	1
+1757	wait . i will msg after  &lt;#&gt;  min.	0
+1758	noooooooo please. last thing i need is stress. for once in your life be fair.	0
+1759	free msg: ringtone!from: http://tms. widelive.com/index. wml?id=1b6a5ecef91ff9*37819&first=true18:0430-jul-05	1
+1760	hey. what happened? u switch off ur cell d whole day. this isnt good. now if u do care, give me a call tomorrow.	0
+1761	except theres a chick with huge boobs.	0
+1762	did you get any gift? this year i didnt get anything. so bad	0
+1763	what's up my own oga. left my phone at home and just saw ur messages. hope you are good. have a great weekend.	0
+1764	maybe?! say hi to  and find out if  got his card. great escape or wetherspoons?	0
+1765	wat happened to the cruise thing	0
+1766	you have an important customer service announcement from premier. call freephone 0800 542 0578 now!	1
+1767	but we havent got da topic yet rite?	0
+1768	u were outbid by simonwatson5120 on the shinco dvd plyr. 2 bid again, visit sms. ac/smsrewards 2 end bid notifications, reply end out	1
+1769	sms. ac jsco: energy is high, but u may not know where 2channel it. 2day ur leadership skills r strong. psychic? reply ans w/question. end? reply end jsco	1
+1770	just checking in on you. really do miss seeing jeremiah. do have a great month	0
+1771	you know what hook up means right?	0
+1772	well good morning mr . hows london treatin' ya treacle?	0
+1773	raji..pls do me a favour. pls convey my birthday wishes to nimya. pls. today is her birthday.	0
+1774	ok...	0
+1775	please sen :)my kind advice :-)please come here and try:-)	0
+1776	on the road so cant txt	0
+1777	cool, what time you think you can get here?	0
+1778	i had askd u a question some hours before. its answer	0
+1779	gudnite....tc...practice going on	0
+1780	im gonna miss u so much	0
+1781	well the general price is  &lt;#&gt; /oz, let me know if/when/how much you want	0
+1782	we are hoping to get away by 7, from langport. you still up for town tonight?	0
+1783	gd luck 4 ur exams :-)	0
+1784	good friends care for each other.. close friends understand each other... and true friends stay forever beyond words, beyond time. gud ni8	0
+1785	after completed degree. there is no use in joining finance.	0
+1786	your board is working fine. the issue of overheating is also reslove. but still software inst is pending. i will come around 8'o clock.	0
+1787	his bday real is in april .	0
+1788	convey my regards to him	0
+1789	sorry, i'll call later	0
+1790	just got part nottingham - 3 hrs 63miles. good thing i love my man so much, but only doing 40mph. hey ho	0
+1791	i cant keep talking to people if am not sure i can pay them if they agree to price. so pls tell me what you want to really buy and how much you are willing to pay	0
+1792	babe, i'm answering you, can't you see me ? maybe you'd better reboot ym ... i got the photo ... it's great !	0
+1793	i dont knw pa, i just drink milk..	0
+1794	hi. i'm sorry i missed your call. can you pls call back.	0
+1795	i just got home babe, are you still awake ?	0
+1796	no dice, art class 6 thru 9 :( thanks though. any idea what time i should come tomorrow?	0
+1797	i'll see if i can swing by in a bit, got some things to take care of here firsg	0
+1798	you best watch what you say cause i get drunk as a motherfucker	0
+1799	there are some nice pubs near here or there is frankie n bennys near the warner cinema?	0
+1800	dear umma she called me now :-)	0
+1801	i am in escape theatre now. . going to watch kavalan in a few minutes	0
+1802	are you still getting the goods.	0
+1803	hey. for me there is no leave on friday. wait i will ask my superior and tell you..	0
+1804	what time do u get out?	0
+1805	its cool but tyler had to take off so we're gonna buy for him and drop it off at his place later tonight. our total order is a quarter, you got enough?	0
+1806	hey , is * rite u put ??10 evey mnth is that all?	0
+1807	book which lesson? then you msg me... i will call up after work or sth... i'm going to get specs. my membership is px3748	0
+1808	im sorry bout last nite it wasn??t ur fault it was me, spouse it was pmt or sumthin! u 4give me? i think u shldxxxx	0
+1809	:( but your not here....	0
+1810	amazing : if you rearrange these letters it gives the same meaning... dormitory = dirty room astronomer = moon starer the eyes = they see election results = lies lets recount mother-in-law = woman hitler eleven plus two =twelve plus one its amazing... !:-)	0
+1811	nothing lor... a bit bored too... then y dun u go home early 2 sleep today...	0
+1812	yeah right! i'll bring my tape measure fri!	0
+1813	i'll meet you in the lobby	0
+1814	i'm in class. will holla later	0
+1815	hi its kate can u give me a ring asap xxx	0
+1816	ok i vl..do u know i got adsense approved..	0
+1817	not tonight mate. catching up on some sleep. this is my new number by the way.	0
+1818	alright, i'll head out in a few minutes, text me where to meet you	0
+1819	sorry, was in the bathroom, sup	0
+1820	i'm leaving my house now...	0
+1821	jordan got voted out last nite!	0
+1822	there's no point hangin on to mr not right if he's not makin u happy	0
+1823	ok...	0
+1824	i liked your new house	0
+1825	cool, i'll text you in a few	0
+1826	ill be at yours in about 3 mins but look out for me	0
+1827	hey i am really horny want to chat or see me naked text hot to 69698 text charged at 150pm to unsubscribe text stop 69698	1
+1828	got fujitsu, ibm, hp, toshiba... got a lot of model how to say...	0
+1829	i dont want to hear philosophy. just say what happen	0
+1830	today am going to college so am not able to atten the class.	0
+1831	i'll let you know when it kicks in	0
+1832	thank you so much. when we skyped wit kz and sura, we didnt get the pleasure of your company. hope you are good. we've given you ultimatum oh! we are countin down to aburo. enjoy!	0
+1833	lol boo i was hoping for a laugh	0
+1834	aah! a cuddle would be lush! i'd need lots of tea and soup before any kind of fumbling!	0
+1835	sounds gd... haha... can... wah, u yan jiu so fast liao...	0
+1836	solve d case : a man was found murdered on  &lt;decimal&gt; . &lt;#&gt;  afternoon. 1,his wife called police. 2,police questioned everyone. 3,wife: sir,i was sleeping, when the murder took place. 4.cook: i was cooking. 5.gardener: i was picking vegetables. 6.house-maid: i went 2 d post office. 7.children: we went 2 play. 8.neighbour: we went 2 a marriage. police arrested d murderer immediately. who's it? reply with reason, if u r brilliant.	0
+1837	stupid.its not possible	0
+1838	o i played smash bros  &lt;#&gt;  religiously.	0
+1839	k.i did't see you.:)k:)where are you now?	0
+1840	free for 1st week! no1 nokia tone 4 ur mob every week just txt nokia to 8007 get txting and tell ur mates www.getzed.co.uk pobox 36504 w45wq norm150p/tone 16+	1
+1841	kate jackson rec center before 7ish, right?	0
+1842	sorry, it's a lot of friend-of-a-friend stuff, i'm just now about to talk to the actual guy who wants to buy	0
+1843	okay but i thought you were the expert	0
+1844	?? still got lessons?  ?? in sch?	0
+1845	eerie nokia tones 4u, rply tone title to 8007 eg tone dracula to 8007 titles: ghost, addamsfa, munsters, exorcist, twilight www.getzed.co.uk pobox36504w45wq 150p	1
+1846	ok...	0
+1847	k.:)you are the only girl waiting in reception ah?	0
+1848	i love to cuddle! i want to hold you in my strong arms right now...	0
+1849	she was supposed to be but couldn't make it, she's still in town though	0
+1850	she's fine. sends her greetings	0
+1851	private! your 2003 account statement for shows 800 un-redeemed s. i. m. points. call 08715203694 identifier code: 40533 expires 31/10/04	1
+1852	u???ve bin awarded ??50 to play 4 instant cash. call 08715203028 to claim. every 9th player wins min ??50-??500. optout 08718727870	1
+1853	lol ok ill try to send. be warned sprint is dead slow. you'll prolly get it tomorrow	0
+1854	spook up your mob with a halloween collection of a logo & pic message plus a free eerie tone, txt card spook to 8007 zed 08701417012150p per logo/pic	1
+1855	hi darlin i finish at 3 do u 1 2 pick me up or meet me? text back on this number luv kate xxx	0
+1856	eh ur laptop got no stock lei... he say mon muz come again to take a look c got a not...	0
+1857	huh but i got lesson at 4 lei n i was thinkin of going to sch earlier n i tot of parkin at kent vale...	0
+1858	rct' thnq adrian for u text. rgds vatian	1
+1859	that's cause your old. i live to be high.	0
+1860	did he just say somebody is named tampa	0
+1861	private! your 2003 account statement for 078	1
+1862	i dont know exactly could you ask chechi.	0
+1863	its a part of checking iq	0
+1864	sounds like something that someone testing me would sayy	0
+1865	ok. how many should i buy.	0
+1866	change again... it's e one next to escalator...	0
+1867	good day to you too.pray for me.remove the teeth as its painful maintaining other stuff.	0
+1868	x2  &lt;#&gt; . are you going to get that	0
+1869	ur going 2 bahamas! callfreefone 08081560665 and speak to a live operator to claim either bahamas cruise of??2000 cash 18+only. to opt out txt x to 07786200117	1
+1870	did u find out what time the bus is at coz i need to sort some stuff out.	0
+1871	i wont do anything de.	0
+1872	thanks for the temales it was wonderful. thank. have a great week.	0
+1873	y lei?	0
+1874	hey.. something came up last min.. think i wun be signing up tmr.. hee	0
+1875	oh... icic... k lor, den meet other day...	0
+1876	is ur lecture over?	0
+1877	hi, mobile no.  &lt;#&gt;  has added you in their contact list on www.fullonsms.com it s a great place to send free sms to people for more visit fullonsms.com	0
+1878	aight, call me once you're close	0
+1879	your brother is a genius	0
+1880	it,,s a taxt massage....tie-pos argh ok! lool!	0
+1881	yes:)here tv is always available in work place..	0
+1882	you call times job today ok umma and ask them to speed up	0
+1883	both :) i shoot big loads so get ready!	0
+1884	sorry, i'll call later	0
+1885	so are you guys asking that i get that slippers again or its gone with last year	0
+1886	pandy joined 4w technologies today.he got job..	0
+1887	wow v v impressed. have funs shopping!	0
+1888	thanks for sending this mental ability question..	0
+1889	did you catch the bus ? are you frying an egg ? did you make a tea? are you eating your mom's left over dinner ? do you feel my love ?	0
+1890	no go. no openings for that room 'til after thanksgiving without an upcharge.	0
+1891	and popping &lt;#&gt; ibuprofens was no help.	0
+1892	pathaya enketa maraikara pa'	0
+1893	k, fyi i'm back in my parents' place in south tampa so i might need to do the deal somewhere else	0
+1894	don know:)this week i'm going to tirunelvai da.	0
+1895	my sort code is  and acc no is . the bank is natwest. can you reply to confirm i've sent this to the right person!	0
+1896	hey company elama po mudyadhu.	0
+1897	time n smile r the two crucial things in our life. sometimes time makes us to forget smile, and sometimes someone's smile makes us to forget time gud noon	0
+1898	guess what! somebody you know secretly fancies you! wanna find out who it is? give us a call on 09065394514 from landline datebox1282essexcm61xn 150p/min 18	1
+1899	cab is available.they pick up and drop at door steps.	0
+1900	eatin my lunch...	0
+1901	dear good morning how you feeling dear	0
+1902	tick, tick, tick ... babe	0
+1903	i wish! i don't think its gonna snow that much. but it will be more than those flurries we usually get that melt before they hit the ground. eek! we haven't had snow since &lt;#&gt; before i was even born!	0
+1904	oh... lk tt den we take e one tt ends at cine lor... dun wan yogasana oso can...	0
+1905	100 dating service cal;l 09064012103 box334sk38ch	1
+1906	sad story of a man - last week was my b'day. my wife did'nt wish me. my parents forgot n so did my kids . i went to work. even my colleagues did not wish. as i entered my cabin my pa said, '' happy b'day boss !!''. i felt special. she askd me 4 lunch. after lunch she invited me to her apartment. we went there. she said,'' do u mind if i go into the bedroom for a minute ? '' ''ok'', i sed in a sexy mood. she came out 5 minuts latr wid a cake...n my wife, my parents, my kidz, my friends n my colleagues. all screaming.. surprise !! and i was waiting on the sofa.. ... ..... ' naked...!	0
+1907	tell me whos this pls:-)	0
+1908	mom wants to know where you at	0
+1909	that's y i said it's bad dat all e gals know u... wat u doing now?	0
+1910	haha i heard that, text me when you're around	0
+1911	thanx a lot 4 ur help!	0
+1912	yep, at derek's house now, see you sunday &lt;3	0
+1913	\hey! do u fancy meetin me at 4 at cha ?? hav a lil beverage on me. if not txt or ring me and we can meet up l8r. quite tired got in at 3 v.pist ;) love pete x x x\""	0
+1914	so its to be poking man everyday that they teach you in canada abi! how are you. just saying hi.	0
+1915	thank you. do you generally date the brothas?	0
+1916	my sister cleared two round in birla soft yesterday.	0
+1917	me too! have a lovely night xxx	0
+1918	as per your request 'maangalyam (alaipayuthe)' has been set as your callertune for all callers. press *9 to copy your friends callertune	0
+1919	i am late,so call you tomorrow morning.take care sweet dreams....u and me...ummifying...bye.	0
+1920	r u saying i should re order the slippers cos i had to pay for returning it.	0
+1921	there're some people by mu, i'm at the table by lambda	0
+1922	what's happening with you. have you gotten a job and have you begun registration for permanent residency	0
+1923	oh god..taken the teeth?is it paining	0
+1924	in which place i can get rooms cheap:-)	0
+1925	hey cutie. how goes it? here in wales its kinda ok. there is like hills and shit but i still avent killed myself.	0
+1926	tell your friends what you plan to do on valentines day @ &lt;url&gt;	0
+1927	hey pple...$700 or $900 for 5 nights...excellent location wif breakfast hamper!!!	0
+1928	sorry, i'll call later	0
+1929	i miss you so much i'm so desparate i have recorded the message you left for me the other day and listen to it just to hear the sound of your voice. i love you	0
+1930	i've been barred from all b and q stores for life!?this twat in orange dungerees came up to me and asked if i wanted decking? so i got the first punch in!!	0
+1931	ok.	0
+1932	evry emotion dsn't hav words.evry wish dsn't hav prayrs.. if u smile,d world is wit u.othrwise even d drop of tear dsn't lik 2 stay wit u.so b happy.. good morning, keep smiling:-)	0
+1933	now only i reached home. . . i am very tired now. . i will come tomorro	0
+1934	i jus hope its true that  missin me cos i'm really missin him! you haven't done anything to feel guilty about, yet.	0
+1935	went fast asleep dear.take care.	0
+1936	if i start sending blackberry torch to nigeria will you find buyer for me?like 4a month. and tell dad not to buy bb from anyone oh.	0
+1937	hot live fantasies call now 08707509020 just 20p per min ntt ltd, po box 1327 croydon cr9 5wb 0870..k	1
+1938	gent! we are trying to contact you. last weekends draw shows that you won a ??1000 prize guaranteed. call 09064012160. claim code k52. valid 12hrs only. 150ppm	1
+1939	miserable. they don't tell u that the side effects of birth control are massive gut wrenching cramps for the first 2 months. i didn't sleep at all last night.	0
+1940	u repeat e instructions again. wat's e road name of ur house?	0
+1941	no worries, hope photo shoot went well. have a spiffing fun at workage.	0
+1942	ugh y can't u just apologize, admit u were wrong and ask me to take u back?	0
+1943	u are subscribed to the best mobile content service in the uk for ??3 per ten days until you send stop to 83435. helpline 08706091795.	1
+1944	aight well keep me informed	0
+1945	yunny... i'm goin to be late	0
+1946	chinatown got porridge, claypot rice, yam cake, fishhead beehoon... either we eat cheap den go cafe n tok or go nydc or somethin...	0
+1947	going to take your babe out ?	0
+1948	lara said she can loan me  &lt;#&gt; .	0
+1949	i see the letter b on my car	0
+1950	fine if that??s the way u feel. that??s the way its gota b	0
+1951	yo, you at jp and hungry like a mofo?	0
+1952	i'm on my way home. went to change batt 4 my watch then go shop a bit lor.	0
+1953	then get some cash together and i'll text jason	0
+1954	your opinion about me? 1. over 2. jada 3. kusruthi 4. lovable 5. silent 6. spl character 7. not matured 8. stylish 9. simple pls reply..	0
+1955	hanging out with my brother and his family	0
+1956	your next amazing xxx picsfree1 video will be sent to you enjoy! if one vid is not enough for 2day text back the keyword picsfree1 to get the next video.	1
+1957	hi this is yijue, can i meet u at 11 tmr?	0
+1958	then u go back urself lor...	0
+1959	if you hear a loud scream in about &lt;#&gt; minutes its cause my gyno will be shoving things up me that don't belong :/	0
+1960	carlos is taking his sweet time as usual so let me know when you and patty are done/want to smoke and i'll tell him to haul ass	0
+1961	sunshine hols. to claim ur med holiday send a stamped self address envelope to drinks on us uk, po box 113, bray, wicklow, eire. quiz starts saturday! unsub stop	1
+1962	alright tyler's got a minor crisis and has to be home sooner than he thought so be here asap	0
+1963	aah bless! how's your arm?	0
+1964	you can jot down things you want to remember later.	0
+1965	em, its olowoyey@ usc.edu have a great time in argentina. not sad about secretary, everything is a blessing	0
+1966	huh but i cant go 2 ur house empty handed right?	0
+1967	message:some text missing* sender:name missing* *number missing *sent:date missing *missing u a lot thats y everything is missing sent via fullonsms.com	0
+1968	great new offer - double mins & double txt on best orange tariffs and get latest camera phones 4 free! call mobileupd8 free on 08000839402 now! or 2stoptxt t&cs	1
+1969	mum, hope you are having a great day. hoping this text meets you well and full of life. have a great day. abiola	0
+1970	fancy a shag? i do.interested? sextextuk.com txt xxuk suzy to 69876. txts cost 1.50 per msg. tncs on website. x	1
+1971	i want to tell you how bad i feel that basically the only times i text you lately are when i need drugs	0
+1972	yeah i don't see why not	0
+1973	yeah i'll try to scrounge something up	0
+1974	fr'ndship is like a needle of a clock. though v r in d same clock, v r nt able 2 met. evn if v meet,itz only 4few seconds. bt v alwys stay conected. gud 9t;-)	0
+1975	hi di is yijue we're meeting at 7 pm at esaplanade tonight.	0
+1976	where are the garage keys? they aren't on the bookshelf	0
+1977	i wanted to ask ?_ to wait 4 me to finish lect. cos my lect finishes in an hour anyway.	0
+1978	i'm going for bath will msg you next  &lt;#&gt;  min..	0
+1979	wa, ur openin sentence very formal... anyway, i'm fine too, juz tt i'm eatin too much n puttin on weight...haha... so anythin special happened?	0
+1980	you are a winner you have been specially selected to receive ??1000 cash or a ??2000 award. speak to a live operator to claim call 087123002209am-7pm. cost 10p	1
+1981	no i don't have cancer. moms making a big deal out of a regular checkup aka pap smear	0
+1982	don know..he is watching film in computer..	0
+1983	we tried to contact you re your reply to our offer of a video handset? 750 anytime any networks mins? unlimited text? camcorder? reply or call 08000930705 now	1
+1984	sorry for the delay. yes masters	0
+1985	aight should i just plan to come up later tonight?	0
+1986	sorry, i'll call later	0
+1987	wife.how she knew the time of murder exactly	0
+1988	the sign of maturity is not when we start saying big things.. but actually it is, when we start understanding small things... *have a nice evening* bslvyl	0
+1989	s:)but he had some luck.2 catches put down:)	0
+1990	did i forget to tell you ? i want you , i need you, i crave you ... but most of all ... i love you my sweet arabian steed ... mmmmmm ... yummy	0
+1991	prakesh is there know.	0
+1992	you didnt complete your gist oh.	0
+1993	i'm in solihull, | do you want anything?	0
+1994	thank u. it better work out cause i will feel used otherwise	0
+1995	it's ok, at least armand's still around	0
+1996	no..few hours before.went to hair cut .	0
+1997	k ill drink.pa then what doing. i need srs model pls send it to my mail id pa.	0
+1998	private! your 2003 account statement for 07808247860 shows 800 un-redeemed s. i. m. points. call 08719899229 identifier code: 40411 expires 06/11/04	1
+1999	just forced myself to eat a slice. i'm really not hungry tho. this sucks. mark is getting worried. he knows i'm sick when i turn down pizza. lol	0
+2000	hi here. have birth at on the  to  at 8lb 7oz. mother and baby doing brilliantly.	0
+2001	ah, well that confuses things, doesnt it? i thought was friends with now. maybe i did the wrong thing but i already sort of invited -tho he may not come cos of money.	0
+2002	you have been specially selected to receive a 2000 pound award! call 08712402050 before the lines close. cost 10ppm. 16+. t&cs apply. ag promo	1
+2003	hey there babe, how u doin? wot u up 2 2nite love annie x.	0
+2004	hard but true: how much you show &amp;  express your love to someone....that much it will hurt when they leave you or you get seperated...!?????_?????ud evening...	0
+2005	where are you ? you said you would be here when i woke ... :-(	0
+2006	science tells that chocolate will melt under the sunlight. please don't walk under the sunlight. bcoz,i don't want to loss a sweet friend.	0
+2007	thanks for the vote. now sing along with the stars with karaoke on your mobile. for a free link just reply with sing now.	1
+2008	i think we're going to finn's now, come	0
+2009	not thought bout it... || drink in tap & spile at seven. || is that pub on gas st off broad st by canal. || ok?	0
+2010	even i cant close my eyes you are in me our vava playing umma :-d	0
+2011	shant disturb u anymore... jia you...	0
+2012	also andros ice etc etc	0
+2013	leave it wif me lar... ?? wan to carry meh so heavy... is da num 98321561 familiar to ?_?	0
+2014	do you know why god created gap between your fingers..? so that, one who is made for you comes &amp; fills those gaps by holding your hand with love..!	0
+2015	3 free tarot texts! find out about your love life now! try 3 for free! text chance to 85555 16 only! after 3 free, msgs ??1.50 each	1
+2016	thats cool! sometimes slow and gentle. sonetimes rough and hard :)	0
+2017	its too late:)but its k.wish you the same.	0
+2018	urgent! call 09066612661 from landline. your complementary 4* tenerife holiday or ??10,000 cash await collection sae t&cs po box 3 wa14 2px 150ppm 18+ sender: hol offer	1
+2019	i'm home. doc gave me pain meds says everything is fine.	0
+2020	26th of july	0
+2021	ou are guaranteed the latest nokia phone, a 40gb ipod mp3 player or a ??500 prize! txt word: collect to no: 83355! ibhltd ldnw15h 150p/mtmsgrcvd18	1
+2022	and now electricity just went out fml.	0
+2023	watching tv lor...	0
+2024	well boy am i glad g wasted all night at applebees for nothing	0
+2025	am i that much bad to avoid like this?	0
+2026	er, hello, things didn???t quite go to plan ??? is limping slowly home followed by aa and with exhaust hanging off	0
+2027	shb b ok lor... thanx...	0
+2028	i'm always on yahoo messenger now. just send the message to me and i.ll get it you may have to send it in the mobile mode sha but i.ll get it. and will reply.	0
+2029	at bruce b downs &amp; fletcher now	0
+2030	ok lor. anyway i thk we cant get tickets now cos like quite late already. u wan 2 go look 4 ur frens a not? darren is wif them now...	0
+2031	moji just informed me that you saved our lives. thanks.	0
+2032	once free call me sir.	0
+2033	most of the tiime when i don't let you hug me it's so i don't break into tears.	0
+2034	when you are big..| god will bring success.	0
+2035	wow! the boys r back. take that 2007 uk tour. win vip tickets & pre-book with vip club. txt club to 81303. trackmarque ltd info@vipclub4u.	1
+2036	hiya, sorry didn't hav signal. i haven't seen or heard from and neither has, which is unusual in itself! i'll put on the case and get him to sort it out! hugs and snogs.	0
+2037	fwiw the reason i'm only around when it's time to smoke is that because of gas i can only afford to be around when someone tells me to be and that apparently only happens when somebody wants to light up	0
+2038	today is \song dedicated day..\" which song will u dedicate for me? send this to all ur valuable frnds but first rply me..."	0
+2039	sent me ur email id soon	0
+2040	thanks honey. have a great day.	0
+2041	lol now i'm after that hot air balloon!	0
+2042	please call our customer service representative on freephone 0808 145 4742 between 9am-11pm as you have won a guaranteed ??1000 cash or ??5000 prize!	1
+2043	hello beautiful r u ok? i've kinda ad a row wiv and he walked out the pub?? i wanted a night wiv u miss u	0
+2044	been running but only managed 5 minutes and then needed oxygen! might have to resort to the roller option!	0
+2045	summers finally here! fancy a chat or flirt with sexy singles in yr area? to get matched up just reply summer now. free 2 join. optout txt stop help08714742804	1
+2046	delhi and chennai still silent.	0
+2047	hope you are having a great day.	0
+2048	nothing just getting msgs by dis name wit different no's..	0
+2049	hey... why dont we just go watch x men and have lunch... haha	0
+2050	not really dude, have no friends i'm afraid :(	0
+2051	plz note: if anyone calling from a mobile co. &amp; asks u to type # &lt;#&gt;  or # &lt;#&gt; . do not do so. disconnect the call,coz it iz an attempt of 'terrorist' to make use of the sim card no. itz confirmd by nokia n motorola n has been verified by cnn ibn.	0
+2052	how about clothes, jewelry, and trips?	0
+2053	74355 xmas iscoming & ur awarded either ??500 cd gift vouchers & free entry 2 r ??100 weekly draw txt music to 87066 tnc	1
+2054	you will be in the place of that man	0
+2055	if you're still up, maybe leave the credit card so i can get gas when i get back like he told me to	0
+2056	this single single answers are we fighting? plus i said am broke and you didnt reply	0
+2057	you still around? i could use a half-8th	0
+2058	u guys never invite me anywhere :(	0
+2059	this is hoping you enjoyed your game yesterday. sorry i've not been in touch but pls know that you are fondly bein thot off. have a great week. abiola	0
+2060	nothing will ever be easy. but don't be looking for a reason not to take a risk on life and love	0
+2061	reply with your name and address and you will receive by post a weeks completely free accommodation at various global locations www.phb1.com ph:08700435505150p	1
+2062	what you doing?how are you?	0
+2063	how i noe... did ?_ specify da domain as nusstu... ?? still in sch...	0
+2064	maybe westshore or hyde park village, the place near my house?	0
+2065	haha i think i did too	0
+2066	okie but i scared u say i fat... then u dun wan me already...	0
+2067	no probs hon! how u doinat the mo?	0
+2068	call 09094100151 to use ur mins! calls cast 10p/min (mob vary). service provided by aom, just gbp5/month. aom box61,m60 1er until u stop. ages 18+ only!	1
+2069	its a big difference.  &lt;#&gt;  versus  &lt;#&gt;  every  &lt;#&gt; hrs	0
+2070	ryder unsold.now gibbs.	0
+2071	thats cool. i am a gentleman and will treat you with dignity and respect.	0
+2072	just come home. i don't want u to be miserable	0
+2073	she doesnt need any test.	0
+2074	hi did u decide wot 2 get 4 his bday if not ill prob jus get him a voucher frm virgin or sumfing	0
+2075	say this slowly.? god,i love you &amp; i need you,clean my heart with your blood.send this to ten special people &amp; u c miracle tomorrow, do it,pls,pls do it...	0
+2076	that means get the door	0
+2077	goldviking (29/m) is inviting you to be his friend. reply yes-762 or no-762 see him: www.sms.ac/u/goldviking stop? send stop frnd to 62468	1
+2078	wat r u doing?	0
+2079	or better still can you catch her and let ask her if she can sell  &lt;#&gt;  for me.	0
+2080	good afternoon, my love! how goes that day ? i hope maybe you got some leads on a job. i think of you, boytoy and send you a passionate kiss from across the sea	0
+2081	i will come tomorrow di	0
+2082	update_now - 12mths half price orange line rental: 400mins...call mobileupd8 on 08000839402 or call2optout=j5q	1
+2083	well, i'm gonna finish my bath now. have a good...fine night.	0
+2084	i only work from mon to thurs but sat i cant leh... booked liao... which other day u free?	0
+2085	dear all, as we know  &lt;#&gt; th is the  &lt;#&gt; th birthday of our loving gopalettan. we are planning to give a small gift on that day. those who like to participate in that you are welcome. please contact our admin team for more details	0
+2086	as if i wasn't having enough trouble sleeping.	0
+2087	we regret to inform u that the nhs has made a mistake.u were never actually born.please report 2 yor local hospital 2b terminated.we r sorry 4 the inconvenience	0
+2088	tell me something. thats okay.	0
+2089	how much it will cost approx . per month.	0
+2090	oi when you gonna ring	0
+2091	where are you call me.	0
+2092	i haven't forgotten you, i might have a couple bucks to send you tomorrow, k? i love ya too	0
+2093	my painful personal thought- \i always try to keep everybody happy all the time. but nobody recognises me when i am alone\""	0
+2094	yes fine	0
+2095	goodmorning, today i am late for  &lt;decimal&gt; min.	0
+2096	hey tmr maybe can meet you at yck	0
+2097	ok i'm gonna head up to usf in like fifteen minutes	0
+2098	happy new year my dear brother. i really do miss you. just got your number and decided to send you this text wishing you only happiness. abiola	0
+2099	life style garments account no please.	0
+2100	carlos says we can pick up from him later so yeah we're set	0
+2101	where to get those?	0
+2102	okmail: dear dave this is your final notice to collect your 4* tenerife holiday or #5000 cash award! call 09061743806 from landline. tcs sae box326 cw25wx 150ppm	1
+2103	much better now thanks lol	0
+2104	sorry, i'll call later	0
+2105	come to mu, we're sorting out our narcotics situation	0
+2106	nutter. cutter. ctter. cttergg. cttargg. ctargg. ctagg. ie you	0
+2107	i don wake since. i checked that stuff and saw that its true no available spaces. pls call the embassy or send a mail to them.	0
+2108	please tell me not all of my car keys are in your purse	0
+2109	i could ask carlos if we could get more if anybody else can chip in	0
+2110	hi:)cts employee how are you?	0
+2111	i need to come home and give you some good lovin...	0
+2112	yup ok...	0
+2113	yeah, give me a call if you've got a minute	0
+2114	phony ??350 award - todays voda numbers ending xxxx are selected to receive a ??350 award. if you have a match please call 08712300220 quoting claim code 3100 standard rates app	1
+2115	went to ganesh dress shop	0
+2116	hi its jess i dont know if you are at work but call me when u can im at home all eve. xxx	0
+2117	don no da:)whats you plan?	0
+2118	yeah it's jus rite...	0
+2119	win the newest ??harry potter and the order of the phoenix (book 5) reply harry, answer 5 questions - chance to be the first among readers!	1
+2120	thinkin about someone is all good. no drugs for that	0
+2121	big brother alert! the computer has selected u for 10k cash or #150 voucher. call 09064018838. ntt po box cro1327 18+ bt landline cost 150ppm mobiles vary	1
+2122	oh k...i'm watching here:)	0
+2123	you have an important customer service announcement from premier.	1
+2124	hmv bonus special 500 pounds of genuine hmv vouchers to be won. just answer 4 easy questions. play now! send hmv to 86688 more info:www.100percent-real.com	1
+2125	she's fine. good to hear from you. how are you my dear? happy new year oh.	0
+2126	been up to ne thing interesting. did you have a good birthday? when are u wrking nxt? i started uni today.	0
+2127	we tried to contact you re our offer of new video phone 750 anytime any network mins half price rental camcorder call 08000930705 or reply for delivery wed	1
+2128	free for 1st week! no1 nokia tone 4 ur mob every week just txt nokia to 87077 get txting and tell ur mates. zed pobox 36504 w45wq norm150p/tone 16+	1
+2129	k and you're sure i don't have to have consent forms to do it :v	0
+2130	don't fret. i'll buy the ovulation test strips and send them to you. you wont get them til like march. can you send me your postal address.u'll be alright.okay.	0
+2131	urgent please call 09066612661 from landline. ??5000 cash or a luxury 4* canary islands holiday await collection. t&cs sae award. 20m12aq. 150ppm. 16+ ???	1
+2132	1st wk free! gr8 tones str8 2 u each wk. txt nokia on to 8007 for classic nokia tones or hit on to 8007 for polys. nokia/150p poly/200p 16+	1
+2133	?? say until like dat i dun buy ericsson oso cannot oredi lar...	0
+2134	you still at the game?	0
+2135	height of recycling: read twice- people spend time for earning money and the same money is spent for spending time!;-) good morning.. keep smiling:-)	0
+2136	k i'll be sure to get up before noon and see what's what	0
+2137	no it was cancelled yeah baby! well that sounds important so i understand my darlin give me a ring later on this fone love kate x	0
+2138	reverse is cheating. that is not mathematics.	0
+2139	i know i'm lacking on most of this particular dramastorm's details but for the most part i'm not worried about that	0
+2140	coffee cake, i guess...	0
+2141	today is accept day..u accept me as? brother sister lover dear1 best1 clos1 lvblefrnd jstfrnd cutefrnd lifpartnr belovd swtheart bstfrnd no rply means enemy	0
+2142	joy's father is john. then john is the name of joy's father. mandan	0
+2143	dai  &lt;#&gt;  naal eruku.	0
+2144	prepare to be pounded every night...	0
+2145	congrats! 1 year special cinema pass for 2 is yours. call 09061209465 now! c suprman v, matrix3, starwars3, etc all 4 free! bx420-ip4-5we. 150pm. dont miss out!	1
+2146	dear voucher holder, to claim this weeks offer, at your pc please go to http://www.wtlp.co.uk/text. ts&cs apply.	1
+2147	freemsg today's the day if you are ready! i'm horny & live in your town. i love sex fun & games! netcollex ltd 08700621170150p per msg reply stop to end	1
+2148	you want to go?	0
+2149	my phone	0
+2150	hiya stu wot u up 2.im in so much truble at home at moment evone hates me even u! wot the hell av i done now? y wont u just tell me text bck please luv dan	0
+2151	your dad is back in ph?	0
+2152	have you bookedthe hut? and also your time off? how are you by the way?	0
+2153	i didn't get the second half of that message	0
+2154	tmr then ?_ brin lar... aiya later i come n c lar... mayb ?_ neva set properly ?_ got da help sheet wif ?_...	0
+2155	watching tv now. i got new job :)	0
+2156	\yeh i am def up4 something sat	0
+2157	ranjith cal drpd deeraj and deepak 5min hold	0
+2158	?? got wat to buy tell us then ?_ no need to come in again.	0
+2159	do you mind if i ask what happened? you dont have to say if it is uncomfortable.	0
+2160	how are you enjoying this semester? take care brother.	0
+2161	sorry, i'll call later	0
+2162	i am hot n horny and willing i live local to you - text a reply to hear strt back from me 150p per msg netcollex ltdhelpdesk: 02085076972 reply stop to end	1
+2163	its a great day. do have yourself a beautiful one.	0
+2164	sorry, i can't help you on this.	0
+2165	and that is the problem. you walk around in \julianaland\" oblivious to what is going on around you. i say the same things constantly and they go in one ear and out the other while you go off doing whatever you want to do. it's not that you don't know why i'm upset--it's that you don't listen when i tell you what is going to upset me. then you want to be surprised when i'm mad."	0
+2166	i finished my lunch already. u wake up already?	0
+2167	our prashanthettan's mother passed away last night. pray for her and family.	0
+2168	yes i posted a couple of pics on fb. there's still snow outside too. i'm just waking up :)	0
+2169	yar i wanted 2 scold u yest but late already... i where got zhong se qing you? if u ask me b4 he ask me then i'll go out w u all lor. n u still can act so real.	0
+2170	oh god i am happy to see your message after 3 days	0
+2171	dear,shall mail tonite.busy in the street,shall update you tonite.things are looking ok.varunnathu edukkukayee raksha ollu.but a good one in real sense.	0
+2172	k.k.this month kotees birthday know?	0
+2173	jus chillaxin, what up	0
+2174	ok lor...	0
+2175	sorry, i can't help you on this.	0
+2176	+449071512431 urgent! this is the 2nd attempt to contact u!u have won ??1250 call 09071512433 b4 050703 t&csbcm4235wc1n3xx. callcost 150ppm mobilesvary. max??7. 50	1
+2177	free message activate your 500 free text messages by replying to this message with the word free for terms & conditions, visit www.07781482378.com	1
+2178	nothing really, just making sure everybody's up to speed	0
+2179	den wat will e schedule b lk on sun?	0
+2180	the greatest test of courage on earth is to bear defeat without losing heart....gn tc	0
+2181	i am getting threats from your sales executive shifad as i raised complaint against him. its an official message.	0
+2182	ding me on ya break fassyole! blacko from londn	0
+2183	i keep seeing weird shit and bein all \woah\" then realising it's actually reasonable and i'm all \"oh\""	0
+2184	k, my roommate also wants a dubsack and another friend may also want some so plan on bringing extra, i'll tell you when they know for sure	0
+2185	shall call now dear having food	0
+2186	?? give me some time to walk there.	0
+2187	dear friends, sorry for the late information. today is the birthday of our loving ar.praveesh. for more details log on to face book and see. its his number + &lt;#&gt; . dont miss a delicious treat.	0
+2188	good good, billy mates all gone. just been jogging, again! did enjoy concert?	0
+2189	when people see my msgs, they think iam addicted to msging... they are wrong, bcoz they don\'t know that iam addicted to my sweet friends..!! bslvyl	0
+2190	yavnt tried yet and never played original either	0
+2191	wanna do some art?! :d	0
+2192	piggy, r u awake? i bet u're still sleeping. i'm going 4 lunch now...	0
+2193	this message is from a great doctor in india:-): 1) do not drink appy fizz. it contains cancer causing age	0
+2194	beauty sleep can help ur pimples too.	0
+2195	happy birthday to you....dear.with lots of love.rakhesh nri	0
+2196	dunno he jus say go lido. same time 930.	0
+2197	in da car park	0
+2198	ball is moving a lot.will spin in last :)so very difficult to bat:)	0
+2199	i think if he rule tamilnadu..then its very tough for our people.	0
+2200	i want to sent  &lt;#&gt; mesages today. thats y. sorry if i hurts	0
+2201	i want to lick your pussy now...	0
+2202	how's it going? got any exciting karaoke type activities planned? i'm debating whether to play football this eve. feeling lazy though.	0
+2203	sure thing big man. i have hockey elections at 6, shouldn???t go on longer than an hour though	0
+2204	no objection. my bf not coming.	0
+2205	someonone you know is trying to contact you via our dating service! to find out who it could be call from your mobile or landline 09064015307 box334sk38ch	1
+2206	(and my man carlos is definitely coming by mu tonight, no excuses)	0
+2207	a boy was late 2 home. his father: \power of frndship\""	0
+2208	yup but it's not giving me problems now so mayb i'll jus leave it...	0
+2209	i know where the  &lt;#&gt;  is, i'll be there around 5	0
+2210	velly good, yes please!	0
+2211	i am in hospital da. . i will return home in evening	0
+2212	does she usually take fifteen fucking minutes to respond to a yes or no question	0
+2213	you have won a guaranteed ??200 award or even ??1000 cashto claim ur award call free on 08000407165 (18+) 2 stop getstop on 88222 php	1
+2214	by march ending, i should be ready. but will call you for sure. the problem is that my capital never complete. how far with you. how's work and the ladies	0
+2215	sports fans - get the latest sports news str* 2 ur mobile 1 wk free plus a free tone txt sport on to 8007 www.getzed.co.uk 0870141701216+ norm 4txt/120p	1
+2216	whatever, im pretty pissed off.	0
+2217	yo theres no class tmrw right?	0
+2218	beautiful tomorrow never comes.. when it comes, it's already today.. in the hunt of beautiful tomorrow don't waste your wonderful today.. goodmorning:)	0
+2219	if anyone calls for a treadmill say you'll buy it. make sure its working. i found an ad on craigslist selling for $ &lt;#&gt; .	0
+2220	usually the person is unconscious that's in children but in adults they may just behave abnormally. i.ll call you now	0
+2221	is toshiba portege m100 gd?	0
+2222	sad story of a man - last week was my b'day. my wife did'nt wish me. my parents forgot n so did my kids . i went to work. even my colleagues did not wish.	0
+2223	oh dang! i didn't mean o send that to you! lol!	0
+2224	wat time r ?_ going to xin's hostel?	0
+2225	thanx 4 e brownie it's v nice...	0
+2226	mmmmm ... it was sooooo good to wake to your words this morning, my love!! mmmm fuck ... i love you too, my lion ... *devouring kiss from across the sea*	0
+2227	is fujitsu s series lifebook good?	0
+2228	you are a very very very very bad girl. or lady.	0
+2229	yup he msg me: is tat yijue? then i tot it's my group mate cos we meeting today mah... i'm askin if ?_ leaving earlier or wat mah cos mayb ?_ haf to walk v far...	0
+2230	hey hun-onbus goin 2 meet him. he wants 2go out 4a meal but i donyt feel like it cuz have 2 get last bus home!but hes sweet latelyxxx	0
+2231	interflora - ??it's not too late to order interflora flowers for christmas call 0800 505060 to place your order before midnight tomorrow.	1
+2232	sorry i missed your call let's talk when you have the time. i'm on 07090201529	1
+2233	congratulations ur awarded either ??500 of cd gift vouchers & free entry 2 our ??100 weekly draw txt music to 87066 tncs www.ldew.com1win150ppmx3age16	1
+2234	are you sure you don't mean \get here	0
+2235	cos daddy arranging time c wat time fetch ?_ mah...	0
+2236	nope... c ?_ then...	0
+2237	someone has contacted our dating service and entered your phone becausethey fancy you! to find out who it is call from a landline 09058098002. pobox1, w14rg 150p	1
+2238	so no messages. had food?	0
+2239	your opinion about me? 1. over 2. jada 3. kusruthi 4. lovable 5. silent 6. spl character 7. not matured 8. stylish 9. simple pls reply..	0
+2240	also north carolina and texas atm, you would just go to the gre site and pay for the test results to be sent.	0
+2241	yes but can we meet in town cos will go to gep and then home. you could text at bus stop. and don't worry we'll have finished by march ??_ ish!	0
+2242	yes. last  practice	0
+2243	dnt worry...use ice pieces in a cloth pack.also take 2 tablets.	0
+2244	oh ic. i thought you meant mary jane.	0
+2245	or just do that 6times	0
+2246	must come later.. i normally bathe him in da afternoon mah..	0
+2247	ugh its been a long day. i'm exhausted. just want to cuddle up and take a nap	0
+2248	congrats ! treat pending.i am not on mail for 2 days.will mail once thru.respect mother at home.check mails.	0
+2249	arms fine, how's cardiff and uni?	0
+2250	don know..wait i will check it.	0
+2251	hey you gave them your photo when you registered for driving ah? tmr wanna meet at yck?	0
+2252	feel like trying kadeem again? :v	0
+2253	remind me how to get there and i shall do so	0
+2254	please attend the phone:)	0
+2255	lol your always so convincing.	0
+2256	mila, age23, blonde, new in uk. i look sex with uk guys. if u like fun with me. text mtalk to 69866.18 . 30pp/txt 1st 5free. ??1.50 increments. help08718728876	1
+2257	pls send me a comprehensive mail about who i'm paying, when and how much.	0
+2258	&lt;#&gt;  great loxahatchee xmas tree burning update: you can totally see stars here	0
+2259	they just talking thats it de. they wont any other.	0
+2260	freemsg: txt: call to no: 86888 & claim your reward of 3 hours talk time to use from your phone now! subscribe6gbp/mnth inc 3hrs 16 stop?txtstop	1
+2261	hey there! glad u r better now. i hear u treated urself to a digi cam, is it good? we r off at 9pm. have a fab new year, c u in coupla wks!	0
+2262	we'll you pay over like  &lt;#&gt; yrs so its not too difficult	0
+2263	just sent you an email ??? to an address with incomm in it, is that right?	0
+2264	watch lor. i saw a few swatch one i thk quite ok. ard 116 but i need 2nd opinion leh...	0
+2265	aiyar dun disturb u liao... thk u have lots 2 do aft ur cupboard come...	0
+2266	lol no ouch but wish i'd stayed out a bit longer	0
+2267	\its ur luck to love someone. its ur fortune to love the one who loves u. but	0
+2268	double your mins & txts on orange or 1/2 price linerental - motorola and sonyericsson with b/tooth free-nokia free call mobileupd8 on 08000839402 or2optout/hv9d	1
+2269	ha ha nan yalrigu heltini..iyo kothi chikku, u shared many things wit me..so far i didn't told any body and even uttered a word abt u.. if ur trusting me so much how can i tell these to others.. plz nxt time dont use those words to me..ok, chikku:-);-)b-)	0
+2270	i am in bus on the way to calicut	0
+2271	cause i'm not freaky lol	0
+2272	i think i am disturbing her da	0
+2273	nah dub but je still buff	0
+2274	this pay is  &lt;decimal&gt;  lakhs:)	0
+2275	sure, i'll see if i can come by in a bit	0
+2276	am in film ill call you later.	0
+2277	discussed with your mother ah?	0
+2278	oooh bed ridden ey? what are you thinking of?	0
+2279	i dunno they close oredi not... ?? v ma fan...	0
+2280	ok... u enjoy ur shows...	0
+2281	vikky, come around  &lt;time&gt; ..	0
+2282	if you don't, your prize will go to another customer. t&c at www.t-c.biz 18+ 150p/min polo ltd suite 373 london w1j 6hl please call back if busy	1
+2283	i dont have that much image in class.	0
+2284	no. yes please. been swimming?	0
+2285	princess, is your kitty shaved or natural?	0
+2286	was actually about to send you a reminder today. have a wonderful weekend	0
+2287	haven't heard anything and he's not answering my texts so i'm guessing he flaked. that said the jb is fantastic	0
+2288	argh my 3g is spotty, anyway the only thing i remember from the research we did was that province and sterling were the only problem-free places we looked at	0
+2289	<forwarded from 21870000>hi - this is your mailbox messaging sms alert. you have 4 messages. you have 21 matches. please call back on 09056242159 to retrieve your messages and matches	1
+2290	no dear i was sleeping :-p	0
+2291	my tuition is at 330. hm we go for the 1120 to 1205 one? do you mind?	0
+2292	kaiez... enjoy ur tuition... gee... thk e second option sounds beta... i'll go yan jiu den msg u...	0
+2293	lol that would be awesome payback.	0
+2294	\ah poor baby!hope urfeeling bettersn luv! probthat overdose of work hey go careful spk 2 u sn lots of lovejen xxx.\""	0
+2295	oh :-)only 4 outside players allowed to play know	0
+2296	soon you will have the real thing princess! do i make you wet? :)	0
+2297	ok. c u then.	0
+2298	u ned to convince him tht its not possible witot hurting his feeling its the main	0
+2299	carlos took a while (again), we leave in a minute	0
+2300	ok...	0
+2301	i'm stuck in da middle of da row on da right hand side of da lt...	0
+2302	get the official england poly ringtone or colour flag on yer mobile for tonights game! text tone or flag to 84199. optout txt eng stop box39822 w111wx ??1.50	1
+2303	ok i found dis pierre cardin one which looks normal costs 20 its on sale.	0
+2304	yes its possible but dint try. pls dont tell to any one k	0
+2305	i am going to film 2day da. at 6pm. sorry da.	0
+2306	for real tho this sucks. i can't even cook my whole electricity is out. and i'm hungry.	0
+2307	please give it 2  or i will pick it up on tuesday evening about 8 if that is ok.	0
+2308	you do your studies alone without anyones help. if you cant no need to study.	0
+2309	dear voucher holder 2 claim your 1st class airport lounge passes when using your holiday voucher call 08704439680. when booking quote 1st class x 2	1
+2310	i sent them. do you like?	0
+2311	i anything lor...	0
+2312	oh is it! which brand?	0
+2313	total video converter free download type this in google search:)	0
+2314	im late tellmiss im on my way	0
+2315	yup, leaving right now, be back soon	0
+2316	ur awarded a city break and could win a ??200 summer shopping spree every wk. txt store to 88039.skilgme.tscs087147403231winawk!age16+??1.50perwksub	1
+2317	same to u...	0
+2318	oh howda gud gud.. mathe en samachara chikku:-)	0
+2319	hmph. go head, big baller.	0
+2320	great! i have to run now so ttyl!	0
+2321	hi, mobile no.  &lt;#&gt;  has added you in their contact list on www.fullonsms.com it s a great place to send free sms to people for more visit fullonsms.com	0
+2322	are you free now?can i call now?	0
+2323	idk. i'm sitting here in a stop and shop parking lot right now bawling my eyes out because i feel like i'm a failure in everything. nobody wants me and now i feel like i'm failing you.	0
+2324	dear 0776xxxxxxx u've been invited to xchat. this is our final attempt to contact u! txt chat to 86688 150p/msgrcvdhg/suite342/2lands/row/w1j6hl ldn 18yrs	1
+2325	jus telling u dat i'll b leaving 4 shanghai on 21st instead so we'll haf more time 2 meet up cya...	0
+2326	88800 and 89034 are premium phone services call 08718711108	1
+2327	by the way, make sure u get train to worc foregate street not shrub hill. have fun night x	0
+2328	please leave this topic..sorry for telling that..	0
+2329	good evening! this is roger. how are you?	0
+2330	yo, any way we could pick something up tonight?	0
+2331	double mins & double txt & 1/2 price linerental on latest orange bluetooth mobiles. call mobileupd8 for the very latest offers. 08000839402 or call2optout/lf56	1
+2332	love you aathi..love u lot..	0
+2333	desires- u going to doctor 4 liver. and get a bit stylish. get ur hair managed. thats it.	0
+2334	hello baby, did you get back to your mom's ? are you setting up the computer now ? filling your belly ? how goes it loverboy ? i miss you already ... *sighs*	0
+2335	what was she looking for?	0
+2336	by monday next week. give me the full gist	0
+2337	dear good morning now only i am up	0
+2338	you getting back any time soon?	0
+2339	customer service announcement. we recently tried to make a delivery to you but were unable to do so, please call 07090298926 to re-schedule. ref:9307622	1
+2340	you have come into my life and brought the sun ..shiny down on me, warming my heart. putting a constant smile on my face ... making me feel loved and cared for	0
+2341	waiting 4 my tv show 2 start lor... u leh still busy doing ur report?	0
+2342	santa calling! would your little ones like a call from santa xmas eve? call 09077818151 to book you time. calls1.50ppm last 3mins 30s t&c www.santacalling.com	1
+2343	i'm in office now . i will call you  &lt;#&gt;  min:)	0
+2344	our brand new mobile music service is now live. the free music player will arrive shortly. just install on your phone to browse content from the top artists.	1
+2345	i want to be inside you every night...	0
+2346	ok i am on the way to home hi hi	0
+2347	play w computer? aiyah i tok 2 u lor?	0
+2348	si si. i think ill go make those oreo truffles.	0
+2349	i like dis sweater fr mango but no more my size already so irritating.	0
+2350	you are guaranteed the latest nokia phone, a 40gb ipod mp3 player or a ??500 prize! txt word: collect to no: 83355! ibhltd ldnw15h 150p/mtmsgrcvd18+	1
+2351	ok.ok ok..then..whats ur todays plan	0
+2352	i dont understand your message.	0
+2353	thank you, winner notified by sms. good luck! no future marketing reply stop to 84122 customer services 08450542832	1
+2354	i had been hoping i would not have to send you this message. my rent is due and i dont have enough for it. my reserves are completely gone. its a loan i need and was hoping you could her. the balance is  &lt;#&gt; . is there a way i could get that from you, till mid march when i hope to pay back.	0
+2355	finished class where are you.	0
+2356	after my work ah... den 6 plus lor... u workin oso rite... den go orchard lor, no other place to go liao...	0
+2357	god picked up a flower and dippeditinadew, lovingly touched itwhichturnedinto u, and the he gifted tomeandsaid,this friend is 4u	0
+2358	wylie update: my weed dealer carlos went to freedom and had a class with lunsford	0
+2359	whatever, juliana. do whatever you want.	0
+2360	the bus leaves at  &lt;#&gt;	0
+2361	you have 1 new voicemail. please call 08719181513.	1
+2362	welcome to select, an o2 service with added benefits. you can now call our specially trained advisors free from your mobile by dialling 402.	1
+2363	gsoh? good with spam the ladies?u could b a male gigolo? 2 join the uk's fastest growing mens club reply oncall. mjzgroup. 08714342399.2stop reply stop. msg@??1.50rcvd	1
+2364	cds 4u: congratulations ur awarded ??500 of cd gift vouchers or ??125 gift guaranteed & freeentry 2 ??100 wkly draw xt music to 87066 tncs www.ldew.com1win150ppmx3age16	1
+2365	shop till u drop, is it you, either 10k, 5k, ??500 cash or ??100 travel voucher, call now, 09064011000. ntt po box cr01327bt fixedline cost 150ppm mobile vary	1
+2366	that's significant but dont worry.	0
+2367	lol its ok i didn't remember til last nite	0
+2368	enjoy ur life. . good night	0
+2369	anytime...	0
+2370	i am in escape theatre now. . going to watch kavalan in a few minutes	0
+2371	this is ur face test ( 1 2 3 4 5 6 7 8 9  &lt;#&gt;  ) select any number i will tell ur face astrology.... am waiting. quick reply...	0
+2372	if i get there before you after your ten billion calls and texts so help me god	0
+2373	zoe it just hit me 2 im fucking shitin myself il defo try my hardest 2 cum 2morow luv u millions lekdog	0
+2374	i think i???m waiting for the same bus! inform me when you get there, if you ever get there.	0
+2375	u so lousy, run already come back then half dead... hee...	0
+2376	probably gonna be here for a while, see you later tonight &lt;)	0
+2377	goal! arsenal 4 (henry, 7 v liverpool 2 henry scores with a simple shot from 6 yards from a pass by bergkamp to give arsenal a 2 goal margin after 78 mins.	1
+2378	congratulations ore mo owo re wa. enjoy it and i wish you many happy moments to and fro wherever you go	0
+2379	5 free top polyphonic tones call 087018728737, national rate. get a toppoly tune sent every week, just text subpoly to 81618, ??3 per pole. unsub 08718727870.	1
+2380	dun need to use dial up juz open da browser n surf...	0
+2381	well if i'm that desperate i'll just call armand again	0
+2382	wat so late still early mah. or we juz go 4 dinner lor. aiya i dunno...	0
+2383	download as many ringtones as u like no restrictions, 1000s 2 choose. u can even send 2 yr buddys. txt sir to 80082 ??3	1
+2384	don't worry though, i understand how important it is that i be put in my place with a poorly thought out punishment in the face of the worst thing that has ever happened to me. brb gonna go kill myself	0
+2385	just finished. missing you plenty	0
+2386	neshanth..tel me who r u?	0
+2387	garbage bags, eggs, jam, bread, hannaford wheat chex	0
+2388	gent! we are trying to contact you. last weekends draw shows that you won a ??1000 prize guaranteed. call 09064012160. claim code k52. valid 12hrs only. 150ppm	1
+2389	ok i go change also...	0
+2390	shall i send that exe to your mail id.	0
+2391	check with nuerologist.	0
+2392	thk some of em find wtc too far... weiyi not goin... e rest i dunno yet... r ur goin 4 dinner den i might b able to join...	0
+2393	u still havent got urself a jacket ah?	0
+2394	private! your 2004 account statement for 078498****7 shows 786 unredeemed bonus points. to claim call 08719180219 identifier code: 45239 expires 06.05.05	1
+2395	i dont have any of your file in my bag..i was in work when you called me.i 'll tell you if i find anything in my room.	0
+2396	oops, i'll let you know when my roommate's done	0
+2397	welp apparently he retired	0
+2398	there the size of elephant tablets & u shove um up ur ass!!	0
+2399	apps class varaya elaya.	0
+2400	mmmm ... fuck ... not fair ! you know my weaknesses ! *grins* *pushes you to your knee's* *exposes my belly and pulls your head to it* don't forget ... i know yours too *wicked smile*	0
+2401	panasonic & bluetoothhdset free. nokia free. motorola free & doublemins & doubletxt on orange contract. call mobileupd8 on 08000839402 or call 2optout	1
+2402	i sent them. do you like?	0
+2403	i'm in a meeting, call me later at	0
+2404	no problem. how are you doing?	0
+2405	jesus christ bitch i'm trying to give you drugs answer your fucking phone	0
+2406	* was really good to see you the other day dudette, been missing you!	0
+2407	your weekly cool-mob tones are ready to download !this weeks new tones include: 1) crazy frog-axel f>>> 2) akon-lonely>>> 3) black eyed-dont p >>>more info in n	1
+2408	just nw i came to hme da..	0
+2409	wa... u so efficient... gee... thanx...	0
+2410	no no. i will check all rooms befor activities	0
+2411	this is the 2nd time we have tried 2 contact u. u have won the 750 pound prize. 2 claim is easy, call 08718726970 now! only 10p per min. bt-national-rate	1
+2412	i'm turning off my phone. my moms telling everyone i have cancer. and my sister won't stop calling. it hurts to talk. can't put up with it. see u when u get home. love u	0
+2413	he needs to stop going to bed and make with the fucking dealing	0
+2414	wif my family booking tour package.	0
+2415	sorry, i can't help you on this.	0
+2416	dude how do you like the buff wind.	0
+2417	your credits have been topped up for http://www.bubbletext.com your renewal pin is tgxxrz	1
+2418	eerie nokia tones 4u, rply tone title to 8007 eg tone dracula to 8007 titles: ghost, addamsfa, munsters, exorcist, twilight www.getzed.co.uk pobox36504w45wq 150p	1
+2419	you've won tkts to the euro2004 cup final or ??800 cash, to collect call 09058099801 b4190604, pobox 7876150ppm	1
+2420	k.. i yan jiu liao... sat we can go 4 bugis vill one frm 10 to 3 den hop to parco 4 nb. sun can go cine frm 1030 to 2, den hop to orc mrt 4 hip hop at 4...	0
+2421	if you were/are free i can give. otherwise nalla adi entey nattil kittum	0
+2422	when you get free, call me	0
+2423	thanks 4 your continued support your question this week will enter u in2 our draw 4 ??100 cash. name the new us president? txt ans to 80082	1
+2424	after the drug she will be able to eat.	0
+2425	r ?_ comin back for dinner?	0
+2426	whatsup there. dont u want to sleep	0
+2427	dunno lei... i might b eatin wif my frens... if ?_ wan to eat then i wait 4 ?_ lar	0
+2428	are you there in room.	0
+2429	hmv bonus special 500 pounds of genuine hmv vouchers to be won. just answer 4 easy questions. play now! send hmv to 86688 more info:www.100percent-real.com	1
+2430	same, i'm at my great aunts anniversary party in tarpon springs	0
+2431	my stomach has been thru so much trauma i swear i just can't eat. i better lose weight.	0
+2432	surly ill give it to you:-) while coming to review.	0
+2433	gud mrng dear hav a nice day	0
+2434	god picked up a flower and dippeditinadew, lovingly touched itwhichturnedinto u, and the he gifted tomeandsaid,this friend is 4u	0
+2435	i was wondering if it would be okay for you to call uncle john and let him know that things are not the same in nigeria as they r here. that  &lt;#&gt;  dollars is 2years sent and that you know its a strain but i plan to pay back every dime he gives. every dime so for me to expect anything from you is not practical. something like that.	0
+2436	keep yourself safe for me because i need you and i miss you already and i envy everyone that see's you in real life	0
+2437	500 free text msgs. just text ok to 80488 and we'll credit your account	1
+2438	am up to my eyes in philosophy	0
+2439	esplanade lor. where else...	0
+2440	u r a winner u ave been specially selected 2 receive ??1000 cash or a 4* holiday (flights inc) speak to a live operator 2 claim 0871277810710p/min (18 )	1
+2441	dear voucher holder have your next meal on us. use the following link on your pc 2 enjoy a 2 4 1 dining experiencehttp://www.vouch4me.com/etlp/dining.asp	1
+2442	ard 530 like dat lor. we juz meet in mrt station then ?_ dun haf to come out.	0
+2443	fyi i'm taking a quick shower, be at epsilon in like  &lt;#&gt;  min	0
+2444	i'm really sorry i lit your hair on fire	0
+2445	tension ah?what machi?any problem?	0
+2446	i enjoy watching and playing football and basketball. anything outdoors. and you?	0
+2447	howz that persons story	0
+2448	nothing. i meant that once the money enters your account here, the bank will remove its flat rate. someone transfered  &lt;#&gt;  to my account and  &lt;#&gt; dollars got removed. so the banks differ and charges also differ.be sure you trust the 9ja person you are sending account details to cos...	0
+2449	yup... from what i remb... i think should be can book...	0
+2450	no i'm good for the movie, is it ok if i leave in an hourish?	0
+2451	private! your 2003 account statement for shows 800 un-redeemed s.i.m. points. call 08715203685 identifier code:4xx26 expires 13/10/04	1
+2452	december only! had your mobile 11mths+? you are entitled to update to the latest colour camera mobile for free! call the mobile update co free on 08002986906	1
+2453	yes princess! i want to please you every night. your wish is my command...	0
+2454	compliments to you. was away from the system. how your side.	0
+2455	my exam is for february 4. wish you a great day.	0
+2456	he dint tell anything. he is angry on me that why you told to abi.	0
+2457	are there ta jobs available? let me know please cos i really need to start working	0
+2458	hahaha..use your brain dear	0
+2459	bloomberg -message center +447797706009 why wait? apply for your future http://careers. bloomberg.com	1
+2460	you got job in wipro:)you will get every thing in life in 2 or 3 years.	0
+2461	sry da..jst nw only i came to home..	0
+2462	im just wondering what your doing right now?	0
+2463	hello, my love ! how went your day ? are you alright ? i think of you, my sweet and send a jolt to your heart to remind you ... i love you! can you hear it ? i screamed it across the sea for all the world to hear. ahmad al hallaq is loved ! and owned ! *possessive passionate kiss*	0
+2464	block breaker now comes in deluxe format with new features and great graphics from t-mobile. buy for just ??5 by replying get bbdeluxe and take the challenge	1
+2465	love it! i want to flood that pretty pussy with cum...	0
+2466	babe! i fucking love you too !! you know? fuck it was so good to hear your voice. i so need that. i crave it. i can't get enough. i adore you, ahmad *kisses*	0
+2467	sorry da:)i was thought of calling you lot of times:)lil busy.i will call you at noon..	0
+2468	are you plans with your family set in stone ?	0
+2469	i'm watching lotr w my sis dis aft. so u wan 2 meet me 4 dinner at nite a not?	0
+2470	dont worry, 1 day very big lambu ji vl come..til then enjoy batchlor party:-)	0
+2471	he didn't see his shadow. we get an early spring yay	0
+2472	ok	0
+2473	i'm outside islands, head towards hard rock and you'll run into me	0
+2474	thank you baby! i cant wait to taste the real thing...	0
+2475	u coming back 4 dinner rite? dad ask me so i re confirm wif u...	0
+2476	i love ya too but try and budget your money better babe. gary would freak on me if he knew	0
+2477	huh i cant thk of more oredi how many pages do we have?	0
+2478	i.ll get there tomorrow and send it to you	0
+2479	there r many model..sony ericson also der.. &lt;#&gt; ..it luks good bt i forgot modl no	0
+2480	apart from the one i told you about yesterday?	0
+2481	i'm fine. hope you are also	0
+2482	lmao where's your fish memory when i need it?	0
+2483	good afternoon sunshine! how dawns that day ? are we refreshed and happy to be alive? do we breathe in the air and smile ? i think of you, my love ... as always	0
+2484	madam,regret disturbance.might receive a reference check from dlf premarica.kindly be informed.rgds,rakhesh,kerala.	0
+2485	thanx 4 sending me home...	0
+2486	the battery is for mr adewale my uncle. aka egbon	0
+2487	yup, no need. i'll jus wait 4 e rain 2 stop.	0
+2488	send me your resume:-)	0
+2489	and  picking them up from various points	0
+2490	that's cool he'll be here all night, lemme know when you're around	0
+2491	\what are youdoing later? sar xxx\""	0
+2492	sary just need tim in the bollox &it hurt him a lot so he tol me!	0
+2493	s:-)if we have one good partnership going we will take lead:)	0
+2494	ranjith cal drpd deeraj and deepak 5min hold	0
+2495	thanx a lot...	0
+2496	if you're thinking of lifting me one then no.	0
+2497	kallis is ready for bat in 2nd innings	0
+2498	auction round 4. the highest bid is now ??54. next maximum bid is ??71. to bid, send bids e. g. 10 (to bid ??10) to 83383. good luck.	1
+2499	a bloo bloo bloo i'll miss the first bowl	0
+2500	never y lei... i v lazy... got wat? dat day ?_ send me da url cant work one...	0
+2501	its posible dnt live in  &lt;#&gt; century cm frwd n thnk different	0
+2502	alright omw, gotta change my order to a half8th	0
+2503	dude. what's up. how teresa. hope you have been okay. when i didnt hear from these people, i called them and they had received the package since dec  &lt;#&gt; . just thot you'ld like to know. do have a fantastic year and all the best with your reading. plus if you can really really bam first aid for usmle, then your work is done.	0
+2504	good. good job. i like entrepreneurs	0
+2505	i havent lei.. next mon can?	0
+2506	?? log off 4 wat. it's sdryb8i	0
+2507	where at were hungry too	0
+2508	i emailed yifeng my part oredi.. can ?_ get it fr him..	0
+2509	no message..no responce..what happend?	0
+2510	\hey hey werethe monkeespeople say we monkeyaround! howdy gorgeous	0
+2511	hey mate. spoke to the mag people. we???re on.  the is deliver by the end of the month. deliver on the 24th sept. talk later.	0
+2512	todays voda numbers ending 1225 are selected to receive a ??50award. if you have a match please call 08712300220 quoting claim code 3100 standard rates app	1
+2513	don???t give a flying monkeys wot they think and i certainly don???t mind. any friend of mine and all that!	0
+2514	waiting for your call.	0
+2515	pls tell nelson that the bb's are no longer comin. the money i was expecting aint coming	0
+2516	yupz... i've oredi booked slots 4 my weekends liao...	0
+2517	hi 07734396839 ibh customer loyalty offer: the new nokia6600 mobile from only ??10 at txtauction!txt word:start to no:81151 & get yours now!4t&	1
+2518	eat jap done oso aft ur lect wat... ?? got lect at 12 rite...	0
+2519	free entry into our ??250 weekly comp just send the word enter to 84128 now. 18 t&c www.textcomp.com cust care 08712405020.	1
+2520	wishing you a wonderful week.	0
+2521	k sure am in my relatives home. sms me de. pls:-)	0
+2522	house-maid is the murderer, coz the man was murdered on  &lt;#&gt; th january.. as public holiday all govt.instituitions are closed,including post office..understand?	0
+2523	too late. i said i have the website. i didn't i have or dont have the slippers	0
+2524	designation is software developer and may be she get chennai:)	0
+2525	4mths half price orange line rental & latest camera phones 4 free. had your phone 11mths+? call mobilesdirect free on 08000938767 to update now! or2stoptxt t&cs	1
+2526	orange customer, you may now claim your free camera phone upgrade for your loyalty. call now on 0207 153 9996. offer ends 14thmarch. t&c's apply. opt-out availa	1
+2527	yes but i dont care! i need you bad, princess!	0
+2528	k still are you loving me.	0
+2529	what he said is not the matter. my mind saying some other matter is there.	0
+2530	mode men or have you left.	0
+2531	like a personal sized or what	0
+2532	shall i come to get pickle	0
+2533	this is my number by vivek..	0
+2534	no. to be nosy i guess. idk am i over reacting if i'm freaked?	0
+2535	tyler (getting an 8th) has to leave not long after 9, can you get here in like an hour?	0
+2536	goodmorning, today i am late for 1hr.	0
+2537	hey are you angry with me. reply me dr.	0
+2538	prabha..i'm soryda..realy..frm heart i'm sory	0
+2539	i thk ?_ gotta go home by urself. cos i'll b going out shopping 4 my frens present.	0
+2540	aight that'll work, thanks	0
+2541	i told her i had a dr appt next week. she thinks i'm gonna die. i told her its just a check. nothing to be worried about. but she didn't listen.	0
+2542	i to am looking forward to all the sex cuddling.. only two more sleeps	0
+2543	sorry dude. dont know how i forgot. even after dan reminded me. sorry. hope you guys had fun.	0
+2544	i love to give massages. i use lots of baby oil... what is your fave position?	0
+2545	can not use foreign stamps in this country.	0
+2546	please call our customer service representative on 0800 169 6031 between 10am-9pm as you have won a guaranteed ??1000 cash or ??5000 prize!	1
+2547	i'll talk to the others and probably just come early tomorrow then	0
+2548	what u mean u almost done? done wif sleeping? but i tot u going to take a nap.. yup i send her liao so i'm picking her up at ard 4 smth lor..	0
+2549	you have registered sinco as payee. log in at icicibank.com and enter urn  &lt;#&gt;  to confirm. beware of frauds. do not share or disclose urn to anyone.	0
+2550	ill be there on  &lt;#&gt;  ok.	0
+2551	oh k :)why you got job then whats up?	0
+2552	2p per min to call germany 08448350055 from your bt line. just 2p per min. check planettalkinstant.com for info & t's & c's. text stop to opt out	1
+2553	ok then i will come to ur home after half an hour	0
+2554	many more happy returns of the day. i wish you happy birthday.	0
+2555	no. i meant the calculation is the same. that  &lt;#&gt; units at  &lt;#&gt; . this school is really expensive. have you started practicing your accent. because its important. and have you decided if you are doing 4years of dental school or if you'll just do the nmde exam.	0
+2556	(bank of granite issues strong-buy) explosive pick for our members *****up over 300% *********** nasdaq symbol cdgt that is a $5.00 per..	1
+2557	jesus armand really is trying to tell everybody he can find	0
+2558	who's there say hi to our drugdealer	0
+2559	i fetch yun or u fetch?	0
+2560	where are you ? what are you doing ? are yuou working on getting the pc to your mom's ? did you find a spot that it would work ? i need you	0
+2561	you always make things bigger than they are	0
+2562	free nokia or motorola with upto 12mths 1/2price linerental, 500 free x-net mins&100txt/mth free b'tooth*. call mobileupd8 on 08001950382 or call 2optout/d3wv	1
+2563	the last thing i ever wanted to do was hurt you. and i didn't think it would have. you'd laugh, be embarassed, delete the tag and keep going. but as far as i knew, it wasn't even up. the fact that you even felt like i would do it to hurt you shows you really don't know me at all. it was messy wednesday, but it wasn't bad. the problem i have with it is you have the time to clean it, but you choose not to. you skype, you take pictures, you sleep, you want to go out. i don't mind a few things here and there, but when you don't make the bed, when you throw laundry on top of it, when i can't have a friend in the house because i'm embarassed that there's underwear and bras strewn on the bed, pillows on the floor, that's something else. you used to be good about at least making the bed.	0
+2564	have a nice day my dear.	0
+2565	\cheers u tex mecause u werebored! yeah okden hunny r uin wk sat?sound??s likeyour havin gr8fun j! keep updat countinlots of loveme xxxxx.\""	0
+2566	imagine you finally get to sink into that bath after i have put you through your paces, maybe even having you eat me for a while before i left ... but also imagine the feel of that cage on your cock surrounded by the bath water, reminding you always who owns you ... enjoy, my cuck	0
+2567	'wnevr i wana fal in luv vth my books, my bed fals in luv vth me..!'' . yen madodu, nav pretsorginta, nammanna pretsovru important alwa....!!:) gud eveb-).	0
+2568	yo, you gonna still be in stock tomorrow/today? i'm trying to get a dubsack	0
+2569	u wake up already? wat u doing? u picking us up later rite? i'm taking sq825, reaching ard 7 smth 8 like dat. u can check e arrival time. c ya soon...	0
+2570	let me know how it changes in the next 6hrs. it can even be appendix but you are out of that age range. however its not impossible. so just chill and let me know in 6hrs	0
+2571	as a valued customer, i am pleased to advise you that following recent review of your mob no. you are awarded with a ??1500 bonus prize, call 09066368470	1
+2572	you are guaranteed the latest nokia phone, a 40gb ipod mp3 player or a ??500 prize! txt word: collect to no: 83355! ibhltd ldnw15h 150p/mtmsgrcvd18	1
+2573	yup. izzit still raining heavily cos i'm in e mrt i can't c outside.	0
+2574	ahhhh...just woken up!had a bad dream about u tho,so i dont like u right now :) i didnt know anything about comedy night but i guess im up for it.	0
+2575	customer place, i wil cal u sir.	0
+2576	do u konw waht is rael friendship im gving yuo an exmpel: jsut ese tihs msg.. evrey splleing of tihs msg is wrnog.. bt sitll yuo can raed it wihtuot ayn mitsake.. goodnight &amp; have a nice sleep..sweet dreams..	0
+2577	eek that's a lot of time especially since american pie is like 8 minutes long. i can't stop singing it.	0
+2578	no gifts!! you trying to get me to throw myself off a cliff or something?	0
+2579	cramps stopped. going back to sleep	0
+2580	well am officially in a philosophical hole, so if u wanna call am at home ready to be saved!	0
+2581	huh so fast... dat means u havent finished painting?	0
+2582	i like cheap! but i???m happy to splash out on the wine if it makes you feel better..	0
+2583	you will be receiving this week's triple echo ringtone shortly. enjoy it!	1
+2584	rats. hey did u ever vote for the next themes?	0
+2585	ok i juz receive..	0
+2586	i'm thinking that chennai forgot to come for auction..	0
+2587	ur going 2 bahamas! callfreefone 08081560665 and speak to a live operator to claim either bahamas cruise of??2000 cash 18+only. to opt out txt x to 07786200117	1
+2588	nooooooo i'm gonna be bored to death all day. cable and internet outage.	0
+2589	want 2 get laid tonight? want real dogging locations sent direct 2 ur mob? join the uk's largest dogging network bt txting gravel to 69888! nt. ec2a. 31p.msg@150p	1
+2590	wat makes some people dearer is not just de happiness dat u feel when u meet them but de pain u feel when u miss dem!!!	0
+2591	have you finished work yet? :)	0
+2592	then why you came to hostel.	0
+2593	and he's apparently bffs with carly quick now	0
+2594	there are many company. tell me the language.	0
+2595	tells u 2 call 09066358152 to claim ??5000 prize. u have 2 enter all ur mobile & personal details @ the prompts. careful!	1
+2596	happy birthday... may u find ur prince charming soon n dun work too hard...	0
+2597	mind blastin.. no more tsunamis will occur from now on.. rajnikant stopped swimming in indian ocean..:-d	0
+2598	thank you. and by the way, i just lost.	0
+2599	that's the way you should stay oh.	0
+2600	me i'm not workin. once i get job...	0
+2601	meeting u is my work. . . tel me when shall i do my work tomorrow	0
+2602	i hate when she does this. she turns what should be a fun shopping trip into an annoying day of how everything would look in her house.	0
+2603	moby pub quiz.win a ??100 high street prize if u know who the new duchess of cornwall will be? txt her first name to 82277.unsub stop ??1.50 008704050406 sp arrow	1
+2604	tell rob to mack his gf in the theater	0
+2605	smile in pleasure smile in pain smile when trouble pours like rain smile when sum1 hurts u smile becoz someone still loves to see u smiling!!	0
+2606	whatsup there. dont u want to sleep	0
+2607	just haven't decided where yet eh ?	0
+2608	free entry in 2 a weekly comp for a chance to win an ipod. txt pod to 80182 to get entry (std txt rate) t&c's apply 08452810073 for details 18+	1
+2609	wot is u up 2 then bitch?	0
+2610	just normal only here :)	0
+2611	i've got  &lt;#&gt; , any way i could pick up?	0
+2612	you said to me before i went back to bed that you can't sleep for anything.	0
+2613	?? go home liao? ask dad to pick me up at 6...	0
+2614	can u get pic msgs to your phone?	0
+2615	get the door, i'm here	0
+2616	just so that you know,yetunde hasn't sent money yet. i just sent her a text not to bother sending. so its over, you dont have to involve yourself in anything. i shouldn't have imposed anything on you in the first place so for that, i apologise.	0
+2617	i am great princess! what are you thinking about me? :)	0
+2618	also track down any lighters you can find	0
+2619	yeah confirmed for you staying at  that weekend	0
+2620	haven't seen my facebook, huh? lol!	0
+2621	it's justbeen overa week since we broke up and already our brains are going to mush!	0
+2622	nothing, smsing u n xy lor. sorry lor da guys neva c u in person but they sort of know u lor. so u wan 2 meet them xy ask me 2 bring u along 4 our next meeting.	0
+2623	nah i don't think he goes to usf, he lives around here though	0
+2624	ooh, 4got, i'm gonna start belly dancing in moseley weds 6.30 if u want 2 join me, they have a cafe too.	0
+2625	oh k.i think most of wi and nz players unsold.	0
+2626	so that means you still think of teju	0
+2627	whats that coming over the hill..... is it a monster! hope you have a great day. things r going fine here, busy though!	0
+2628	ok lar... joking wif u oni...	0
+2629	haha yeah i see that now, be there in a sec	0
+2630	draw va?i dont think so:)	0
+2631	hey elaine, is today's meeting still on?	0
+2632	k, i'll work something out	0
+2633	k, i might come by tonight then if my class lets out early	0
+2634	good morning, my boytoy! how's those yummy lips ? where's my sexy buns now ? what do you do ? do you think of me ? do you crave me ? do you need me ?	0
+2635	hey whats up? u sleeping all morning?	0
+2636	can you just come in for a sec? there's somebody here i want you to see	0
+2637	ok. every night take a warm bath drink a cup of milk and you'll see a work of magic. you still need to loose weight. just so that you know	0
+2638	i can send you a pic if you like :)	0
+2639	but you dint in touch with me.	0
+2640	not heard from u4 a while. call 4 rude chat private line 01223585334 to cum. wan 2c pics of me gettin shagged then text pix to 8552. 2end send stop 8552 sam xxx	1
+2641	y cant u try new invention to fly..i'm not joking.,	0
+2642	i'm putting it on now. it should be ready for  &lt;time&gt;	0
+2643	7 lor... change 2 suntec... wat time u coming?	0
+2644	u goin out 2nite?	0
+2645	u dun say so early hor... u c already then say...	0
+2646	moon has come to color your dreams, stars to make them musical and my sms to give you warm and peaceful sleep. good night	0
+2647	and also i've sorta blown him off a couple times recently so id rather not text him out of the blue looking for weed	0
+2648	not..tel software name..	0
+2649	why de. you looking good only:-)..	0
+2650	got but got 2 colours lor. one colour is quite light n e other is darker lor. actually i'm done she's styling my hair now.	0
+2651	i know you are serving. i mean what are you doing now.	0
+2652	urgent! we are trying to contact u. todays draw shows that you have won a ??800 prize guaranteed. call 09050003091 from land line. claim c52. valid 12hrs only	1
+2653	ummmmmaah many many happy returns of d day my dear sweet heart.. happy birthday dear	0
+2654	hi there, 2nights ur lucky night! uve been invited 2 xchat, the uks wildest chat! txt chat to 86688 now! 150p/msgrcvdhg/suite342/2lands/row/w1j6hl ldn 18yrs	1
+2655	free entry in 2 a wkly comp to win fa cup final tkts 21st may 2005. text fa to 87121 to receive entry question(std txt rate)t&c's apply 08452810075over18's	1
+2656	you give us back my id proof and  &lt;#&gt;  rs. we wont allow you to work. we will come to your home within days	0
+2657	hi dear call me its urgnt. i don't know whats your problem. you don't want to work or if you have any other problem at least tell me. wating for your reply.	0
+2658	captain vijaykanth is doing comedy in captain tv..he is drunken :)	0
+2659	u still painting ur wall?	0
+2660	want to funk up ur fone with a weekly new tone reply tones2u 2 this text. www.ringtones.co.uk, the original n best. tones 3gbp network operator rates apply	1
+2661	how much did ur hdd casing cost.	0
+2662	*deep sigh* ... i miss you :-( ... i am really surprised you haven't gone to the net cafe yet to get to me ... don't you miss me?	0
+2663	okay lor... will they still let us go a not ah? coz they will not know until later. we drop our cards into the box right?	0
+2664	aight fuck it, i'll get it later	0
+2665	do you know why god created gap between your fingers..? so that, one who is made for you comes &amp; fills those gaps by holding your hand with love..!	0
+2666	me also da, i feel yesterday night  wait til 2day night dear.	0
+2667	my planning usually stops at \find hella weed	0
+2668	would u fuckin believe it they didnt know i had thurs pre booked off so they re cancelled me again! that needs to b sacked	0
+2669	oh rite. well im with my best mate pete, who i went out with 4 a week+ now were 2geva again. its been longer than a week.	0
+2670	if india win or level series means this is record:)	0
+2671	i'm a guy, browsin is compulsory	0
+2672	i thought slide is enough.	0
+2673	sorry me going home first... daddy come fetch ?_ later...	0
+2674	k..k:)how much does it cost?	0
+2675	i dont. can you send it to me. plus how's mode.	0
+2676	hot live fantasies call now 08707509020 just 20p per min ntt ltd, po box 1327 croydon cr9 5wb 0870 is a national rate call	1
+2677	total disappointment, when i texted you was the craziest shit got :(	0
+2678	are you willing to go for apps class.	0
+2679	ok i am on the way to railway	0
+2680	what type of stuff do you sing?	0
+2681	will ?_ b going to esplanade fr home?	0
+2682	dunno lei he neva say...	0
+2683	living is very simple.. loving is also simple.. laughing is too simple.. winning is tooo simple.. but, being 'simple' is very difficult.. gud nte.:-	0
+2684	pls ask macho how much is budget for bb bold 2 is cos i saw a new one for  &lt;#&gt;  dollars.	0
+2685	sorry sir, i will call you tomorrow.  senthil.hsbc	0
+2686	what your plan for pongal?	0
+2687	urgent! you have won a 1 week free membership in our ??100,000 prize jackpot! txt the word: claim to no: 81010 t&c www.dbuk.net lccltd pobox 4403ldnw1a7rw18	1
+2688	need a coffee run tomo?can't believe it's that time of week already	0
+2689	the greatest test of courage on earth is to bear defeat without losing heart....gn tc	0
+2690	no did you check? i got his detailed message now	0
+2691	free2day sexy st george's day pic of jordan!txt pic to 89080 dont miss out, then every wk a saucy celeb!4 more pics c pocketbabe.co.uk 0870241182716 ??3/wk	1
+2692	pls she needs to dat slowly or she will vomit more.	0
+2693	every day i use to sleep after  &lt;#&gt;  so only.	0
+2694	jay says he'll put in  &lt;#&gt;	0
+2695	ah, well that confuses things, doesn???t it?	0
+2696	\alright babe	0
+2697	ha ha - had popped down to the loo when you hello-ed me. hello!	0
+2698	was the farm open?	0
+2699	ok lor thanx... ?? in school?	0
+2700	sorry brah, just finished the last of my exams, what up	0
+2701	guess which pub im in? im as happy as a pig in clover or whatever the saying is!	0
+2702	am not interested to do like that.	0
+2703	ok... i din get ur msg...	0
+2704	life is more strict than teacher... bcoz teacher teaches lesson &amp; then conducts exam, but life first conducts exam &amp; then teaches lessons. happy morning. . .	0
+2705	i will see in half an hour	0
+2706	u too...	0
+2707	tomarrow final hearing on my laptop case so i cant.	0
+2708	i sent my scores to sophas and i had to do secondary application for a few schools. i think if you are thinking of applying, do a research on cost also. contact joke ogunrinde, her school is one me the less expensive ones	0
+2709	also where's the piece	0
+2710	it just seems like weird timing that the night that all you and g want is for me to come smoke is the same day as when a shitstorm is attributed to me always coming over and making everyone smoke	0
+2711	yes we were outside for like 2 hours. and i called my whole family to wake them up cause it started at 1 am	0
+2712	how is it possible to teach you. and where.	0
+2713	from 88066 lost ??12 help	1
+2714	sorry about that this is my mates phone and i didnt write it love kate	0
+2715	uhhhhrmm isnt having tb test bad when youre sick	0
+2716	dear got train and seat mine lower seat	0
+2717	mm that time you dont like fun	0
+2718	did u see what i posted on your facebook?	0
+2719	ok lor. i ned 2 go toa payoh 4 a while 2 return smth u wan 2 send me there or wat?	0
+2720	, im .. on the snowboarding trip. i was wondering if your planning to get everyone together befor we go..a meet and greet kind of affair? cheers,	0
+2721	just sleeping..and surfing	0
+2722	s'fine. anytime. all the best with it.	0
+2723	think you sent the text to the home phone. that cant display texts. if you still want to send it his number is	0
+2724	yeah, probably but not sure. ilol let u know, but personally i wuldnt bother, then again if ur goin to then i mite as well!!	0
+2725	money i have won wining number 946 wot do i do next	1
+2726	on the way to office da..	0
+2727	lmao!nice 1	0
+2728	sorry i'm not free...	0
+2729	call me when you/carlos is/are here, my phone's vibrate is acting up and i might not hear texts	0
+2730	can you tell shola to please go to college of medicine and visit the academic department, tell the academic secretary what the current situation is and ask if she can transfer there. she should ask someone to check sagamu for the same thing and lautech. its vital she completes her medical education in nigeria. its less expensive much less expensive. unless she will be getting citizen rates in new zealand.	0
+2731	where r we meeting?	0
+2732	theyre doing it to lots of places. only hospitals and medical places are safe.	0
+2733	hey now am free you can call me.	0
+2734	yeah imma come over cause jay wants to do some drugs	0
+2735	free-message: jamster!get the crazy frog sound now! for poly text mad1, for real text mad2 to 88888. 6 crazy sounds for just 3 gbp/week! 16+only! t&c's apply	1
+2736	sms services. for your inclusive text credits, pls goto www.comuk.net login= ***** unsubscribe with stop. no extra charge. help:08700469649. po box420. ip4 5we	1
+2737	\keep ur problems in ur heart	0
+2738	am surfing online store. for offers do you want to buy any thing.	0
+2739	\gimme a few\" was  &lt;#&gt;  minutes ago"	0
+2740	for fear of fainting with the of all that housework you just did? quick have a cuppa	0
+2741	daddy, shu shu is looking 4 u... u wan me 2 tell him u're not in singapore or wat?	0
+2742	got ur mail dileep.thank you so muchand look forward to lots of support...very less contacts here,remember one venugopal you mentioned.tomorrow if not late,i shall try to come up till there.goodnight dear.	0
+2743	i thank you so much for all you do with selflessness. i love you plenty.	0
+2744	bugis oso near wat...	0
+2745	i have to take exam with march 3	0
+2746	maybe if you woke up before fucking 3 this wouldn't be a problem.	0
+2747	yes i started to send requests to make it but pain came back so i'm back in bed. double coins at the factory too. i gotta cash in all my nitros.	0
+2748	yo, you around? just got my car back	0
+2749	aathi..where are you dear..	0
+2750	hi. hope you had a good day. have a better night.	0
+2751	dude we should go sup again	0
+2752	is she replying. has boye changed his phone number	0
+2753	boo what time u get out? u were supposed to take me shopping today. :(	0
+2754	and how's your husband.	0
+2755	is there coming friday is leave for pongal?do you get any news from your work place.	0
+2756	doc prescribed me morphine cause the other pain meds aren't enough. waiting for my mom to bring it. that med should kick in fast so i'm gonna try to be on later	0
+2757	thanks for ve lovely wisheds. you rock	0
+2758	hi my email address has changed now it is	0
+2759	come to me right now, ahmad	0
+2760	house-maid is the murderer, coz the man was murdered on  &lt;#&gt; th january.. as public holiday all govt.instituitions are closed,including post office..understand?	0
+2761	hi! you just spoke to maneesha v. we'd like to know if you were satisfied with the experience. reply toll free with yes or no.	0
+2762	ya, told..she was asking wats matter?	0
+2763	i dont thnk its a wrong calling between us	0
+2764	hai dear friends... this is my new &amp; present number..:) by rajitha raj (ranju)	0
+2765	thanks for being there for me just to talk to on saturday. you are very dear to me. i cherish having you as a brother and role model.	0
+2766	oh, then your phone phoned me but it disconnected	0
+2767	we know someone who you know that fancies you. call 09058097218 to find out who. pobox 6, ls15hb 150p	1
+2768	prof: you have passed in all the papers in this sem congrats . . . . student: enna kalaachutaarama..!! prof:???? gud mrng!	0
+2769	ok... but they said i've got wisdom teeth hidden inside n mayb need 2 remove.	0
+2770	good stuff, will do.	0
+2771	jus finish blowing my hair. u finish dinner already?	0
+2772	still i have not checked it da. . .	0
+2773	i'll text you when i drop x off	0
+2774	yeah, that's what i was thinking	0
+2775	sorry completely forgot * will pop em round this week if your still here?	0
+2776	you have won a guaranteed 32000 award or maybe even ??1000 cash to claim ur award call free on 0800 ..... (18+). its a legitimat efreefone number wat do u think???	1
+2777	how come it takes so little time for a child who is afraid of the dark to become a teenager who wants to stay out all night?	1
+2778	i'm working technical support :)voice process.networking field.	0
+2779	freemsg: our records indicate you may be entitled to 3750 pounds for the accident you had. to claim for free reply with yes to this msg. to opt out text stop	1
+2780	are you the cutest girl in the world or what	0
+2781	k, if u bored up just come to my home..	0
+2782	the word \checkmate\" in chess comes from the persian phrase \"shah maat\" which means; \"the king is dead..\" goodmorning.. have a good day..:)"	0
+2783	horrible bf... i now v hungry...	0
+2784	don't look back at the building because you have no coat and i don't want you to get more sick. just hurry home and wear a coat to the gym!!!	0
+2785	back in brum! thanks for putting us up and keeping us all and happy. see you soon	0
+2786	yo guess what i just dropped	0
+2787	no! but we found a diff farm shop to buy some cheese. on way back now, can i call in?	0
+2788	\getting tickets 4 walsall tue 6 th march. my mate is getting me them on sat. ill pay my treat. want 2 go. txt bak .terry\""	0
+2789	where u been hiding stranger?	0
+2790	freemsg>fav xmas tones!reply real	1
+2791	then anything special?	0
+2792	you can stop further club tones by replying \stop mix\" see my-tone.com/enjoy. html for terms. club tones cost gbp4.50/week. mfl	1
+2793	i've sent my wife your text. after we buy them she'll tell you what to do. so just relax. we should go get them this wkend.	0
+2794	lol wtf random. btw is that your lunch break	0
+2795	i'll be at mu in like  &lt;#&gt;  seconds	0
+2796	and miss vday the parachute and double coins??? u must not know me very well...	0
+2797	i'm in a meeting, call me later at	0
+2798	then dun wear jeans lor...	0
+2799	hey sexy buns ! have i told you ? i adore you, loverboy. i hope you remember to thank your sister in law for those meatballs *grins* ... i love you, babe	0
+2800	mmm ... fuck .... merry christmas to me	0
+2801	dont talk to him ever ok its my word.	0
+2802	lol u still feeling sick?	0
+2803	urgent! your mobile number has been awarded with a ??2000 prize guaranteed. call 09061790121 from land line. claim 3030. valid 12hrs only 150ppm	1
+2804	kind of. took it to garage. centre part of exhaust needs replacing. part ordered n taking it to be fixed tomo morning.	0
+2805	i am on the way to tirupur.	0
+2806	i've reached sch already...	0
+2807	hey... are you going to quit soon? xuhui and i working till end of the month	0
+2808	reminder from o2: to get 2.50 pounds free call credit and details of great offers pls reply 2 this text with your valid name, house no and postcode	1
+2809	hope this text meets you smiling. if not then let this text give you a reason to smile. have a beautiful day.	0
+2810	go where n buy? juz buy when we get there lar.	0
+2811	you have won a guaranteed ??200 award or even ??1000 cashto claim ur award call free on 08000407165 (18+) 2 stop getstop on 88222 php. rg21 4jx	1
+2812	you bad girl. i can still remember them	0
+2813	private! your 2003 account statement for shows 800 un-redeemed s. i. m. points. call 08715203652 identifier code: 42810 expires 29/10/0	1
+2814	fuuuuck i need to stop sleepin, sup	0
+2815	rose for red,red for blood,blood for heart,heart for u. but u for me.... send tis to all ur friends.. including me.. if u like me.. if u get back, 1-u r poor in relation! 2-u need some 1 to support 3-u r frnd 2 many 4-some1 luvs u 5+- some1 is praying god to marry u.:-) try it....	0
+2816	oh, i will get paid. the most outstanding one is for a commercial i did for hasbro...in august! they made us jump through so many hoops to get paid. still not.	0
+2817	let me know how to contact you. i've you settled in a room. lets know you are ok.	0
+2818	wow didn't think it was that common. i take it all back ur not a freak! unless u chop it off:-)	0
+2819	yep. i do like the pink furniture tho.	0
+2820	look at amy ure a beautiful, intelligent woman and i like u a lot. i know u don??t like me like that so don??t worry.	0
+2821	excellent. i spent  &lt;#&gt;  years in the air force. iraq and afghanistan. i am stable and honest. do you like traveling?	0
+2822	free top ringtone -sub to weekly ringtone-get 1st week free-send subpoly to 81618-?3 per week-stop sms-08718727870	1
+2823	hi darlin its kate are u up for doin somethin tonight? im going to a pub called the swan or something with my parents for one drink so phone me if u can	0
+2824	if i let you do this, i want you in the house by 8am.	0
+2825	yup...	0
+2826	come back to tampa ffffuuuuuuu	0
+2827	weightloss! no more girl friends. make loads of money on ebay or something. and give thanks to god.	0
+2828	we currently have a message awaiting your collection. to collect your message just call 08718723815.	1
+2829	if you text on your way to cup stop that should work. and that should be bus	0
+2830	yup i shd haf ard 10 pages if i add figures... ?? all got how many pages?	0
+2831	eat at old airport road... but now 630 oredi... got a lot of pple...	0
+2832	dunno dat's wat he told me. ok lor...	0
+2833	text me when you get off, don't call, my phones having problems	0
+2834	2 and half years i missed your friendship:-)	0
+2835	sorry, i'll call later	0
+2836	the affidavit says  &lt;#&gt;  e twiggs st, division g, courtroom  &lt;#&gt; , &lt;time&gt;  am. i'll double check and text you again tomorrow	0
+2837	i dunno lei... like dun haf...	0
+2838	yes, princess. toledo.	0
+2839	on hen night. going with a swing	0
+2840	not much no fights. it was a good nite!!	0
+2841	i.ll give her once i have it. plus she said grinule greet you whenever we speak	0
+2842	romcapspam everyone around should be responding well to your presence since you are so warm and outgoing. you are bringing in a real breath of sunshine.	1
+2843	you see the requirements please	0
+2844	ok. there may be a free gym about.	0
+2845	your right! i'll make the appointment right now.	0
+2846	send me yetty's number pls.	0
+2847	have a good trip. watch out for . remember when you get back we must decide about easter.	0
+2848	shopping lor. them raining mah hard 2 leave orchard.	0
+2849	mm feeling sleepy. today itself i shall get that dear	0
+2850	all the lastest from stereophonics, marley, dizzee racal, libertines and the strokes! win nookii games with flirt!! click themob wap bookmark or text wap to 82468	1
+2851	you have won ?1,000 cash or a ?2,000 prize! to claim, call09050000327	1
+2852	thats cool. i want to please you...	0
+2853	horrible u eat macs eat until u forgot abt me already rite... u take so long 2 reply. i thk it's more toot than b4 so b prepared. now wat shall i eat?	0
+2854	what can i do? might accidant tookplace between somewhere ghodbandar rd. traffic moves slovely. so plz slip &amp; don't worry.	0
+2855	they said ?_ dun haf passport or smth like dat.. or ?_ juz send to my email account..	0
+2856	not yet chikku..k, then wat abt tht guy did he stopped irritating or msging to u..	0
+2857	y dun cut too short leh. u dun like ah? she failed. she's quite sad.	0
+2858	then what about further plan?	0
+2859	no i am not having not any movies in my laptop	0
+2860	am only searching for good dual sim mobile pa.	0
+2861	that's very rude, you on campus?	0
+2862	get ur 1st ringtone free now! reply to this msg with tone. gr8 top 20 tones to your phone every week just ??1.50 per wk 2 opt out send stop 08452810071 16	1
+2863	my friend just got here and says he's upping his order by a few grams (he's got $ &lt;#&gt; ), when can you get here?	0
+2864	8 at the latest, g's still there if you can scrounge up some ammo and want to give the new ak a try	0
+2865	i have a date on sunday with will!!	0
+2866	i forgot 2 ask ?_ all smth.. there's a card on da present lei... how? ?? all want 2 write smth or sign on it?	0
+2867	tbs/persolvo. been chasing us since sept for??38 definitely not paying now thanks to your information. we will ignore them. kath. manchester.	1
+2868	oh yes, why is it like torture watching england?	0
+2869	i am on the way to ur home	0
+2870	i don't know jack shit about anything or i'd say/ask something helpful but if you want you can pretend that i did and just text me whatever in response to the hypotheticalhuagauahahuagahyuhagga	0
+2871	k da:)how many page you want?	0
+2872	he is there. you call and meet him	0
+2873	oh oh... wasted... den muz chiong on sat n sun liao...	0
+2874	ok... thanx... gd nite 2 ?_ too...	0
+2875	and that's fine, i got enough bud to last most of the night at least	0
+2876	all done, all handed in. don't know if mega shop in asda counts as celebration but thats what i'm doing!	0
+2877	i am not at all happy with what you saying or doing	0
+2878	sorry, i'll call later	0
+2879	if u dun drive then how i go 2 sch.	0
+2880	kit strip - you have been billed 150p. netcollex ltd. po box 1013 ig11 oja	1
+2881	watching tv lor... y she so funny we bluff her 4 wat. izzit because she thk it's impossible between us?	0
+2882	k. i will sent it again	0
+2883	well i will watch shrek in 3d!!b)	0
+2884	i'm reaching home in 5 min.	0
+2885	congratulations you've won. you're a winner in our august ??1000 prize draw. call 09066660100 now. prize code 2309.	1
+2886	are you at work right now ?	0
+2887	what's the significance?	0
+2888	hi baby ive just got back from work and i was wanting to see u allday! i hope i didnt piss u off on the phone today. if u are up give me a call xxx	0
+2889	multiply the numbers independently and count decimal points then, for the division, push the decimal places like i showed you.	0
+2890	dear,me at cherthala.in case u r coming cochin pls call bfore u start.i shall also reach accordingly.or tell me which day u r coming.tmorow i am engaged ans its holiday.	0
+2891	sorry to trouble u again. can buy 4d for my dad again? 1405, 1680, 1843. all 2 big 1 small, sat n sun. thanx.	0
+2892	get your garden ready for summer with a free selection of summer bulbs and seeds worth ??33:50 only with the scotsman this saturday. to stop go2 notxt.co.uk	1
+2893	the beauty of life is in next second.. which hides thousands of secrets. i wish every second will be wonderful in ur life...!! gud n8	0
+2894	we can go 4 e normal pilates after our intro...	0
+2895	* you gonna ring this weekend or wot?	0
+2896	you've already got a flaky parent. it'snot supposed to be the child's job to support the parent...not until they're the ride age anyway. i'm supposed to be there to support you. and now i've hurt you. unintentional. but hurt nonetheless.	0
+2897	i guess it is useless calling u 4 something important.	0
+2898	gr8. so how do you handle the victoria island traffic. plus when's the album due	0
+2899	&lt;#&gt;  am i think? should say on syllabus	0
+2900	oh, the grand is having a bit of a party but it doesn't mention any cover charge so it's probably first come first served	0
+2901	yo you around? a friend of mine's lookin to pick up later tonight	0
+2902	shhhhh nobody is supposed to know!	0
+2903	?? thk of wat to eat tonight.	0
+2904	lol, oh you got a friend for the dog ?	0
+2905	i notice you like looking in the shit mirror youre turning into a right freak	0
+2906	i want to show you the world, princess :) how about europe?	0
+2907	finally the match heading towards draw as your prediction.	0
+2908	going for dinner.msg you after.	0
+2909	it took mr owl 3 licks	0
+2910	i know that my friend already told that.	0
+2911	she is our sister.. she belongs 2 our family.. she is d hope of tomorrow.. pray 4 her,who was fated 4 d shoranur train incident. lets hold our hands together &amp; fuelled by love &amp; concern prior 2 her grief &amp; pain. pls join in dis chain &amp; pass it. stop violence against women.	0
+2912	greetings me, ! consider yourself excused.	0
+2913	ola would get back to you maybe not today but i ve told him you can be his direct link in the us in getting cars he bids for online, you arrange shipping and you get a cut. or u????? for a partnership where u????? invest money for shipping and he takes care of the rest!u??wud b self reliant soon dnt worry	0
+2914	k k:) sms chat with me.	0
+2915	k i'm leaving soon, be there a little after 9	0
+2916	you ve won! your 4* costa del sol holiday or ??5000 await collection. call 09050090044 now toclaim. sae, tc s, pobox334, stockport, sk38xh, cost??1.50/pm, max10mins	1
+2917	if you are not coughing then its nothing	0
+2918	ok darlin i supose it was ok i just worry too much.i have to do some film stuff my mate and then have to babysit again! but you can call me there.xx	0
+2919	think i might have to give it a miss. am teaching til twelve, then have lecture at two. damn this working thing.	0
+2920	urgent! your mobile number has been awarded a <ukp>2000 prize guaranteed. call 09061790125 from landline. claim 3030. valid 12hrs only 150ppm	1
+2921	is it ok if i stay the night here? xavier has a sleeping bag and i'm getting tired	0
+2922	tessy..pls do me a favor. pls convey my birthday wishes to nimya..pls dnt forget it. today is her birthday shijas	0
+2923	so how's scotland. hope you are not over showing your jjc tendencies. take care. live the dream	0
+2924	the new deus ex game comin early next yr	0
+2925	well i know z will take care of me. so no worries.	0
+2926	umma my life and vava umma love you lot dear	0
+2927	lol i have to take it. member how i said my aunt flow didn't visit for 6 months? it's cause i developed ovarian cysts. bc is the only way to shrink them.	0
+2928	sunshine quiz! win a super sony dvd recorder if you canname the capital of australia? text mquiz to 82277. b	1
+2929	btw regarding that we should really try to see if anyone else can be our 4th guy before we commit to a random dude	0
+2930	kothi print out marandratha.	0
+2931	edison has rightly said, \a fool can ask more questions than a wise man can answer\" now you know why all of us are speechless during viva.. gm ge gnt:-)"	0
+2932	i was up all night too worrying about this appt. it's a shame we missed a girls night out with quizzes popcorn and you doing my hair.	0
+2933	hi baby im cruisin with my girl friend what r u up 2? give me a call in and hour at home if thats alright or fone me on this fone now love jenny xxx	0
+2934	u wake up already? thanx 4 e tau sar piah it's quite nice.	0
+2935	double mins and txts 4 6months free bluetooth on orange. available on sony, nokia motorola phones. call mobileupd8 on 08000839402 or call2optout/n9dx	1
+2936	hi babe its chloe, how r u? i was smashed on saturday night, it was great! how was your weekend? u been missing me? sp visionsms.com text stop to stop 150p/text	1
+2937	you said not now. no problem. when you can. let me know.	0
+2938	you were supposed to wake me up &gt;:(	0
+2939	i can take you at like noon	0
+2940	2 celebrate my b??day, y else?	0
+2941	i anything lor.	0
+2942	ok set let u noe e details later...	0
+2943	i'm hungry buy smth home...	0
+2944	k..u also dont msg or reply to his msg..	0
+2945	it???s reassuring, in this crazy world.	0
+2946	miss call miss call khelate kintu opponenter miss call dhorte lage. thats d rule. one with great phone receiving quality wins.	0
+2947	sleeping nt feeling well	0
+2948	aight, we'll head out in a few	0
+2949	aiyah sorry lor... i watch tv watch until i forgot 2 check my phone.	0
+2950	oh god. i'm gonna google nearby cliffs now.	0
+2951	i'm not driving... raining! then i'll get caught at e mrt station lor.	0
+2952	long beach lor. expected... u having dinner now?	0
+2953	i wont get concentration dear you know you are my mind and everything :-)	0
+2954	no problem with the renewal. i.ll do it right away but i dont know his details.	0
+2955	ok i wont call or disturb any one. i know all are avoiding me. i am a burden for all	0
+2956	thanks for your ringtone order, reference t91. you will be charged gbp 4 per week. you can unsubscribe at anytime by calling customer services on 09057039994	1
+2957	havent stuck at orchard in my dad's car. going 4 dinner now. u leh? so r they free tonight?	0
+2958	no de. but call me after some time. ill tell you k	0
+2959	i take it the post has come then! you must have 1000s of texts now! happy reading. my one from wiv hello caroline at the end is my favourite. bless him	0
+2960	i want to send something that can sell fast.  &lt;#&gt; k is not easy money.	0
+2961	urgent, important information for o2 user. today is your lucky day! 2 find out why log onto http://www.urawinner.com there is a fantastic surprise awaiting for you	1
+2962	uh, heads up we don't have that much left	0
+2963	can you please ask macho what his price range is, does he want something new or used plus it he only interfued in the blackberry bold  &lt;#&gt;  or any bb	0
+2964	text82228>> get more ringtones, logos and games from www.txt82228.com. questions: info@txt82228.co.uk	1
+2965	okie	0
+2966	hi darlin im on helens fone im gonna b up the princes 2 nite please come up tb love kate	0
+2967	yes.i'm in office da:)	0
+2968	7 wonders in my world 7th you 6th ur style 5th ur smile 4th ur personality 3rd ur nature 2nd ur sms and 1st \ur lovely friendship\"... good morning dear"	0
+2969	good words.... but words may leave u in dismay many times.	0
+2970	u gd lor go shopping i got stuff to do. u wan 2 watch infernal affairs a not? come lar...	0
+2971	do you work all this week ?	0
+2972	where's my boytoy? i miss you ... what happened?	0
+2973	you have 1 new message. please call 08718738034.	1
+2974	nah can't help you there, i've never had an iphone	0
+2975	k..k..i'm also fine:)when will you complete the course?	0
+2976	so i asked how's anthony. dad. and your bf	0
+2977	i am taking half day leave bec i am not well	0
+2978	congrats 2 mobile 3g videophones r yours. call 09063458130 now! videochat wid ur mates, play java games, dload polyph music, noline rentl. bx420. ip4. 5we. 150p	1
+2979	yeh. indians was nice. tho it did kane me off a bit he he. we shud go out 4 a drink sometime soon. mite hav 2 go 2 da works 4 a laugh soon. love pete x x	0
+2980	allo! we have braved the buses and taken on the trains and triumphed. i mean we???re in b???ham. have a jolly good rest of week	0
+2981	yes, my reg is  ciao!	0
+2982	sir, i am waiting for your call, once free please call me.	0
+2983	your opinion about me? 1. over 2. jada 3. kusruthi 4. lovable 5. silent 6. spl character 7. not matured 8. stylish 9. simple pls reply..	0
+2984	haha, that was the first person i was gonna ask	0
+2985	you tell what happen dont behave like this to me. ok no need to say	0
+2986	who were those people ? were you in a tour ? i thought you were doing that sofa thing you sent me ? your curious sugar	0
+2987	yar but they say got some error.	0
+2988	just curious because my cuz asked what i was up to	0
+2989	i might come to kerala for 2 days.so you can be prepared to take a leave once i finalise .dont plan any travel during my visit.need to finish urgent works.	0
+2990	its normally hot mail. com you see!	0
+2991	actually getting ready to leave the house.	0
+2992	what's up. do you want me to come online?	0
+2993	a boy loved a gal. he propsd bt she didnt mind. he gv lv lttrs, bt her frnds threw thm. again d boy decided 2 aproach d gal , dt time a truck was speeding towards d gal. wn it was about 2 hit d girl,d boy ran like hell n saved her. she asked 'hw cn u run so fast?' d boy replied \boost is d secret of my energy\" n instantly d girl shouted \"our energy\" n thy lived happily 2gthr drinking boost evrydy moral of d story:- i hv free msgs:d;): gud ni8"	0
+2994	just now saw your message.it k da:)	0
+2995	dear sir,salam alaikkum.pride and pleasure meeting you today at the tea shop.we are pleased to send you our contact number at qatar.rakhesh an indian.pls save our number.respectful regards.	0
+2996	yes but i don't care cause i know its there!	0
+2997	then ur physics get a-?	0
+2998	pls go there today  &lt;#&gt; . i dont want any excuses	0
+2999	haha just kidding, papa needs drugs	0
+3000	yo my trip got postponed, you still stocked up?	0
+3001	hey ! don't forget ... you are mine ... for me ... my possession ... my property ... mmm ... *childish smile* ...	0
+3002	aight, sounds good. when do you want me to come down?	0
+3003	hi hun! im not comin 2nite-tell every1 im sorry 4 me, hope u ava goodtime!oli rang melnite ifink it mite b sorted,but il explain everythin on mon.l8rs.x	0
+3004	i'm in inside office..still filling forms.don know when they leave me.	0
+3005	urgent! we are trying to contact u todays draw shows that you have won a ??800 prize guaranteed. call 09050000460 from land line. claim j89. po box245c2150pm	1
+3006	no need to ke qi... ?? too bored izzit y suddenly thk of this...	0
+3007	thanks chikku..:-) gud nyt:-*	0
+3008	bull. your plan was to go floating off to ikea with me without a care in the world. so i have to live with your mess another day.	0
+3009	good morning plz call me sir	0
+3010	i will treasure every moment we spend together...	0
+3011	urgent! call 09066350750 from your landline. your complimentary 4* ibiza holiday or 10,000 cash await collection sae t&cs po box 434 sk3 8wp 150 ppm 18+	1
+3012	ummma.will call after check in.our life will begin from qatar so pls pray very hard.	0
+3013	some of them told accenture is not confirm. is it true.	0
+3014	hard live 121 chat just 60p/min. choose your girl and connect live. call 09094646899 now! cheap chat uk's biggest live service. vu bcm1896wc1n3xx	1
+3015	the current leading bid is 151. to pause this auction send out. customer care: 08718726270	1
+3016	customer loyalty offer:the new nokia6650 mobile from only ??10 at txtauction! txt word: start to no: 81151 & get yours now! 4t&ctxt tc 150p/mtmsg	1
+3017	painful words- \i thought being happy was the most toughest thing on earth... but	0
+3018	ok i msg u b4 i leave my house.	0
+3019	i'm going 2 orchard now laready me reaching soon. u reaching?	0
+3020	hi princess! thank you for the pics. you are very pretty. how are you?	0
+3021	hello. damn this christmas thing. i think i have decided to keep this mp3 that doesnt work.	0
+3022	i love you !!! you know? can you feel it? does it make your belly warm? i wish it does, my love ... i shall meet you in your dreams, ahmad ... *adoring kiss*	0
+3023	someone u know has asked our dating service 2 contact you! cant guess who? call 09058091854 now all will be revealed. po box385 m6 6wu	1
+3024	you know, wot people wear. t shirts, jumpers, hat, belt, is all we know. we r at cribbs	0
+3025	rose needs water, season needs change, poet needs imagination..my phone needs ur sms and i need ur lovely frndship forever....	0
+3026	somebody set up a website where you can play hold em using eve online spacebucks	0
+3027	you have 1 new message. call 0207-083-6089	1
+3028	get down in gandhipuram and walk to cross cut road. right side &lt;#&gt; street road and turn at first right.	0
+3029	gain the rights of a wife.dont demand it.i am trying as husband too.lets see	0
+3030	i have 2 sleeping bags, 1 blanket and paper and  phone details. anything else?	0
+3031	okie	0
+3032	give one miss from that number please	0
+3033	first answer my question.	0
+3034	yeah, probably here for a while	0
+3035	your bill at 3 is ??33.65 so thats not bad!	0
+3036	alright, i'll make sure the car is back tonight	0
+3037	then u drive lor.	0
+3038	december only! had your mobile 11mths+? you are entitled to update to the latest colour camera mobile for free! call the mobile update co free on 08002986906	1
+3039	not sure yet, still trying to get a hold of him	0
+3040	another month. i need chocolate weed and alcohol.	0
+3041	a little. meds say take once every 8 hours. it's only been 5 but pain is back. so i took another. hope i don't die	0
+3042	then she buying today? ?? no need to c meh...	0
+3043	congrats! 1 year special cinema pass for 2 is yours. call 09061209465 now! c suprman v, matrix3, starwars3, etc all 4 free! bx420-ip4-5we. 150pm. dont miss out!	1
+3044	boltblue tones for 150p reply poly# or mono# eg poly3 1. cha cha slide 2. yeah 3. slow jamz 6. toxic 8. come with me or stop 4 more tones txt more	1
+3045	that was random saw my old roomate on campus. he graduated	0
+3046	actually i deleted my old website..now i m blogging at magicalsongs.blogspot.com	0
+3047	here got ur favorite oyster... n got my favorite sashimi... ok lar i dun say already... wait ur stomach start rumbling...	0
+3048	just wait till end of march when el nino gets himself. oh.	0
+3049	?? neva tell me how i noe... i'm not at home in da aft wat...	0
+3050	claim a 200 shopping spree, just call 08717895698 now! have you won! mobstorequiz10ppm	1
+3051	bishan lar nearer... no need buy so early cos if buy now i gotta park my car...	0
+3052	it didnt work again oh. ok goodnight then. i.ll fix and have it ready by the time you wake up. you are very dearly missed have a good night sleep.	0
+3053	they don't put that stuff on the roads to keep it from getting slippery over there?	0
+3054	your chance to be on a reality fantasy show call now = 08707509020 just 20p per min ntt ltd, po box 1327 croydon cr9 5wb 0870 is a national = rate call.	1
+3055	hey babe, sorry i didn't get sooner. gary can come and fix it cause he thinks he knows what it is but he doesn't go as far a ptbo and he says it will cost  &lt;#&gt;  bucks. i don't know if it might be cheaper to find someone there ? we don't have any second hand machines at all right now, let me know what you want to do babe	0
+3056	\im at arestaurant eating squid! i will be out about 10:30 wanna dosomething or is that to late?\""	0
+3057	as a registered subscriber yr draw 4 a ??100 gift voucher will b entered on receipt of a correct ans. when are the next olympics. txt ans to 80062	1
+3058	09066362231 urgent! your mobile no 07xxxxxxxxx won a ??2,000 bonus caller prize on 02/06/03! this is the 2nd attempt to reach you! call 09066362231 asap!	1
+3059	ok that would b lovely, if u r sure. think about wot u want to do, drinkin, dancin, eatin, cinema, in, out, about... up to u! wot about ?	0
+3060	i am not having her number sir	0
+3061	sorry man my account's dry or i would, if you want we could trade back half or i could buy some shit with my credit card	0
+3062	dizzamn, aight i'll ask my suitemates when i get back	0
+3063	ha ha cool cool chikku chikku:-):-db-)	0
+3064	i'm okay. chasing the dream. what's good. what are you doing next.	0
+3065	my parents, my kidz, my friends n my colleagues. all screaming.. surprise !! and i was waiting on the sofa.. ... ..... ' naked...!	0
+3066	my house here e sky quite dark liao... if raining then got excuse not 2 run already rite... hee...	0
+3067	yup i'm free...	0
+3068	good morning my dear........... have a great &amp; successful day.	0
+3069	just wondering, the others just took off	0
+3070	i am at a party with alex nichols	0
+3071	could you not read me, my love ? i answered you	0
+3072	ur balance is now ??500. ur next question is: who sang 'uptown girl' in the 80's ? 2 answer txt ur answer to 83600. good luck!	1
+3073	ok... then r we meeting later?	0
+3074	hi chachi tried calling u now unable to reach u .. pl give me a missed cal once u c tiz msg  kanagu	0
+3075	yeah i can still give you a ride	0
+3076	what number do u live at? is it 11?	0
+3077	ok lor wat time ?_ finish?	0
+3078	none of that's happening til you get here though	0
+3079	ummmmmaah many many happy returns of d day my dear sweet heart.. happy birthday dear	0
+3080	sun ah... thk mayb can if dun have anythin on... thk have to book e lesson... e pilates is at orchard mrt u noe hor...	0
+3081	oh ho. is this the first time u use these type of words	0
+3082	sorry light turned green, i meant another friend wanted  &lt;#&gt;  worth but he may not be around	0
+3083	no you'll just get a headache trying to figure it out. u can trust me to do the math. i promise. o:-)	0
+3084	pls give her prometazine syrup. 5mls then  &lt;#&gt; mins later feed.	0
+3085	i thk 50 shd be ok he said plus minus 10.. did ?_ leave a line in between paragraphs?	0
+3086	i'm job profile seems like bpo..	0
+3087	* was thinking about chuckin ur red green n black trainners 2 save carryin them bac on train	0
+3088	hello my little party animal! i just thought i'd buzz you as you were with your friends ...*grins*... reminding you were loved and send a naughty adoring kiss	0
+3089	come to me, slave. your doing it again ... going into your shell and unconsciously avoiding me ... you are making me unhappy :-(	0
+3090	call to the number which is available in appointment. and ask to connect the call to waheed fathima.	0
+3091	fun fact: although you would think armand would eventually build up a tolerance or some shit considering how much he smokes, he gets fucked up in like 2 hits	0
+3092	pity, * was in mood for that. so...any other suggestions?	0
+3093	no way i'm going back there!	0
+3094	hey... very inconvenient for your sis a not huh?	0
+3095	dont know supports ass and srt i thnk. i think ps3 can play through usb too	0
+3096	o was not into fps then.	0
+3097	wish u many many returns of the day.. happy birthday vikky..	0
+3098	ugh fuck it i'm resubbing to eve	0
+3099	where are you ? what do you do ? how can you stand to be away from me ? doesn't your heart ache without me ? don't you wonder of me ? don't you crave me ?	0
+3100	lol yes. but it will add some spice to your day.	0
+3101	get ready for  &lt;#&gt;  inches of pleasure...	0
+3102	i didnt get anything da	0
+3103	oh... okie lor...we go on sat...	0
+3104	nah im goin 2 the wrks with j wot bout u?	0
+3105	if you don't respond imma assume you're still asleep and imma start calling n shit	0
+3106	1000's flirting now! txt girl or bloke & ur name & age, eg girl zoe 18 to 8007 to join and get chatting!	1
+3107	i was gonna ask you lol but i think its at 7	0
+3108	long time. you remember me today.	0
+3109	the basket's gettin full so i might be by tonight	0
+3110	i'll see, but prolly yeah	0
+3111	urgent! last weekend's draw shows that you have won ??1000 cash or a spanish holiday! call now 09050000332 to claim. t&c: rstm, sw7 3ss. 150ppm	1
+3112	babe? you said 2 hours and it's been almost 4 ... is your internet down ?	0
+3113	aight, text me tonight and we'll see what's up	0
+3114	uncle boye. i need movies oh. guide me. plus you know torrents are not particularly legal here. and the system is slowing down. what should i do. have a gr8 day. plus have you started cos i dont meet you online. how was the honey moon.	0
+3115	it's ?? only $140 ard...?? rest all ard $180 at least...which is ?? price 4 ?? 2 bedrm ($900)	0
+3116	beautiful truth against gravity.. read carefully: \our heart feels light when someone is in it.. but it feels very heavy when someone leaves it..\" goodmorning"	0
+3117	wanna get laid 2nite? want real dogging locations sent direct to ur mobile? join the uk's largest dogging network. txt park to 69696 now! nyt. ec2a. 3lp ??1.50/msg	1
+3118	hey i will be really pretty late... you want to go for the lesson first? i will join you. i'm only reaching tp mrt	0
+3119	congrats! nokia 3650 video camera phone is your call 09066382422 calls cost 150ppm ave call 3mins vary from mobiles 16+ close 300603 post bcm4284 ldn wc1n3xx	1
+3120	that's cool, i'll come by like  &lt;#&gt; ish	0
+3121	these won't do. have to move on to morphine	0
+3122	mathews or tait or edwards or anderson	0
+3123	ok...	0
+3124	moby pub quiz.win a ??100 high street prize if u know who the new duchess of cornwall will be? txt her first name to 82277.unsub stop ??1.50 008704050406 sp	1
+3125	when're you guys getting back? g said you were thinking about not staying for mcr	0
+3126	do you like italian food?	0
+3127	argh why the fuck is nobody in town ;_;	0
+3128	alrite	0
+3129	no 1 polyphonic tone 4 ur mob every week! just txt pt2 to 87575. 1st tone free ! so get txtin now and tell ur friends. 150p/tone. 16 reply hl 4info	1
+3130	that means from february to april i'll be getting a place to stay down there so i don't have to hustle back and forth during audition season as i have since my sister moved away from harlem.	0
+3131	lmao. take a pic and send it to me.	0
+3132	but i dint slept in afternoon.	0
+3133	you need to get up. now.	0
+3134	saw guys and dolls last night with patrick swayze it was great	0
+3135	speaking of does he have any cash yet?	0
+3136	hey!!! i almost forgot ... happy b-day babe ! i love ya!!	0
+3137	did u turn on the heater? the heater was on and set to &lt;#&gt; degrees.	0
+3138	any pain on urination any thing else?	0
+3139	todays voda numbers ending 5226 are selected to receive a ?350 award. if you hava a match please call 08712300220 quoting claim code 1131 standard rates app	1
+3140	that's what i love to hear :v see you sundayish, then	0
+3141	yup having my lunch buffet now.. u eat already?	0
+3142	do you want a new video phone750 anytime any network mins 150 text for only five pounds per week call 08000776320 now or reply for delivery tomorrow	1
+3143	ok no problem... yup i'm going to sch at 4 if i rem correctly...	0
+3144	sometimes heart remembrs someone very much... forgets someone soon... bcoz heart will not like everyone. but liked ones will be remembered everytime... bslvyl	0
+3145	?? ready then call me...	0
+3146	he will, you guys close?	0
+3147	mmmmmmm *snuggles into you* ...*deep contented sigh* ... *whispers* ... i fucking love you so much i can barely stand it ...	0
+3148	my friend, she's studying at warwick, we've planned to go shopping and to concert tmw, but it may be canceled, havn't seen  for ages, yeah we should get together sometime!	0
+3149	i probably won't eat at all today. i think i'm gonna pop. how was your weekend? did u miss me?	0
+3150	dad says hurry the hell up	0
+3151	horrible gal. me in sch doing some stuff. how come u got mc?	0
+3152	i not at home now lei...	0
+3153	anyway i don't think i can secure anything up here, lemme know if you want me to drive down south and chill	0
+3154	?? takin linear algebra today?	0
+3155	wiskey brandy rum gin beer vodka scotch shampain wine \kudi\"yarasu dhina vaazhthukkal. .."	0
+3156	honeybee said: *i'm d sweetest in d world* god laughed &amp; said: *wait,u havnt met d person reading this msg* moral: even god can crack jokes! gm+gn+ge+gn:)	0
+3157	kind of. just missed train cos of asthma attack, nxt one in half hr so driving in. not sure where to park.	0
+3158	dude while were makin those weirdy brownies my sister made awesome cookies. i took pics.	0
+3159	they finally came to fix the ceiling.	0
+3160	i'm not. she lip synced with shangela.	0
+3161	thanx 4 2day! u r a goodmate i think ur rite sary! asusual!1 u cheered me up! love u franyxxxxx	0
+3162	new textbuddy chat 2 horny guys in ur area 4 just 25p free 2 receive search postcode or at gaytextbuddy.com. txt one name to 89693. 08715500022 rpl stop 2 cnl	1
+3163	no da. . vijay going to talk in jaya tv	0
+3164	wah... okie okie... muz make use of e unlimited... haha...	0
+3165	pls i wont belive god.not only jesus.	0
+3166	please reserve ticket on saturday eve from chennai to thirunelvali and again from tirunelvali to chennai on sunday eve...i already see in net..no ticket available..i want to book ticket through tackle ..	0
+3167	ha! i wouldn't say that i just didn't read anything into way u seemed. i don't like 2 be judgemental....i save that for fridays in the pub!	0
+3168	and you! will expect you whenever you text! hope all goes well tomo	0
+3169	i wan but too early lei... me outside now wun b home so early... neva mind then...	0
+3170	i cant pick the phone right now. pls send a message	0
+3171	no shoot me. i'm in the docs waiting room. :/	0
+3172	brainless baby doll..:-d;-), vehicle sariyag drive madoke barolla..	0
+3173	as one of our registered subscribers u can enter the draw 4 a 100 g.b. gift voucher by replying with enter. to unsubscribe text stop	1
+3174	good evening! how are you?	0
+3175	do you know what mallika sherawat did yesterday? find out now @  &lt;url&gt;	0
+3176	hurry up, i've been weed-deficient for like three days	0
+3177	japanese proverb: if one can do it, u too can do it, if none can do it,u must do it indian version: if one can do it, let him do it.. if none can do it,leave it!! and finally kerala version: if one can do it, stop him doing it.. if none can do it, make a strike against it ...	0
+3178	don no da:)whats you plan?	0
+3179	right on brah, see you later	0
+3180	it vl bcum more difficult..	0
+3181	my life means a lot to me, not because i love my life, but because i love the people in my life, the world calls them friends, i call them my world:-).. ge:-)..	0
+3182	k go and sleep well. take rest:-).	0
+3183	men like shorter ladies. gaze up into his eyes.	0
+3184	holy living christ what is taking you so long	0
+3185	make that 3! 4 fucks sake?! x	0
+3186	hasn't that been the pattern recently crap weekends?	0
+3187	can i meet ?_ at 5.. as 4 where depends on where ?_ wan 2 in lor..	0
+3188	it's still not working. and this time i also tried adding zeros. that was the savings. the checking is  &lt;#&gt;	0
+3189	yes! the only place in town to meet exciting adult singles is now in the uk. txt chat to 86688 now! 150p/msg.	1
+3190	free unlimited hardcore porn direct 2 your mobile txt porn to 69200 & get free access for 24 hrs then chrgd@50p per day txt stop 2exit. this msg is free	1
+3191	want to finally have lunch today?	0
+3192	havent planning to buy later. i check already lido only got 530 show in e afternoon. u finish work already?	0
+3193	wishing you and your family merry \x\" mas and happy new year in advance.."	0
+3194	i think u have the wrong number.	0
+3195	no probably  &lt;#&gt; %.	0
+3196	this is the 2nd time we have tried to contact u. u have won the ??1450 prize to claim just call 09053750005 b4 310303. t&cs/stop sms 08718725756. 140ppm	1
+3197	well done england! get the official poly ringtone or colour flag on yer mobile! text tone or flag to 84199 now! opt-out txt eng stop. box39822 w111wx ??1.50	1
+3198	sounds like you have many talents! would you like to go on a dinner date next week?	0
+3199	no break time one... how... i come out n get my stuff fr ?_?	0
+3200	he is impossible to argue with and he always treats me like his sub, like he never released me ... which he did and i will remind him of that if necessary	0
+3201	haha... really oh no... how? then will they deduct your lesson tmr?	0
+3202	cos i was out shopping wif darren jus now n i called him 2 ask wat present he wan lor. then he started guessing who i was wif n he finally guessed darren lor.	0
+3203	want to send me a virtual hug?... i need one	0
+3204	prabha..i'm soryda..realy..frm heart i'm sory	0
+3205	you will be in the place of that man	0
+3206	oh shut it. omg yesterday i had a dream that i had 2 kids both boys. i was so pissed. not only about the kids but them being boys. i even told mark in my dream that he was changing diapers cause i'm not getting owed in the face.	0
+3207	ta-daaaaa! i am home babe, are you still up ?	0
+3208	t-mobile customer you may now claim your free camera phone upgrade & a pay & go sim card for your loyalty. call on 0845 021 3680.offer ends 28thfeb.t&c's apply	1
+3209	babe ? i lost you ... will you try rebooting ?	0
+3210	you are being ripped off! get your mobile content from www.clubmoby.com call 08717509990 poly/true/pix/ringtones/games six downloads for only 3	1
+3211	i love you. you set my soul on fire. it is not just a spark. but it is a flame. a big rawring flame. xoxo	0
+3212	no dude, its not fake..my frnds got money, thts y i'm reffering u..if u member wit my mail link, u vl be credited  &lt;#&gt; rs and il be getiing  &lt;#&gt; rs..i can draw my acc wen it is  &lt;#&gt; rs..	0
+3213	aiyah e rain like quite big leh. if drizzling i can at least run home.	0
+3214	take care n get well soon	0
+3215	here is my new address -apples&pairs&all that malarky	0
+3216	fuck babe ... what happened to you ? how come you never came back?	0
+3217	wot u up 2 j?	0
+3218	good afternon, my love. how are today? i hope your good and maybe have some interviews. i wake and miss you babe. a passionate kiss from across the sea	0
+3219	\hey das cool... iknow all 2 wellda peril of studentfinancial crisis!spk 2 u l8r.\""	0
+3220	why don't you wait 'til at least wednesday to see if you get your .	0
+3221	haha... dont be angry with yourself... take it as a practice for the real thing. =)	0
+3222	lol they were mad at first but then they woke up and gave in.	0
+3223	yoyyooo u know how to change permissions for a drive in mac. my usb flash drive	0
+3224	(that said can you text him one more time?)	0
+3225	only if you promise your getting out as soon as you can. and you'll text me in the morning to let me know you made it in ok.	0
+3226	then we wait 4 u lor... no need 2 feel bad lar...	0
+3227	east coast	0
+3228	dont think so. it turns off like randomlly within 5min of opening	0
+3229	lol ... have you made plans for new years?	0
+3230	what is the plural of the noun research?	0
+3231	que pases un buen tiempo or something like that	0
+3232	wamma get laid?want real doggin locations sent direct to your mobile? join the uks largest dogging network. txt dogs to 69696 now!nyt. ec2a. 3lp ??1.50/msg.	1
+3233	nothing spl..wat abt u and whr ru?	0
+3234	hey babe! i saw you came online for a second and then you disappeared, what happened ?	0
+3235	why i come in between you people	0
+3236	you know my old dom i told you about yesterday ? his name is roger? he got in touch with me last night and wants me to meet him today at 2 pm	0
+3237	are you going to wipro interview today?	0
+3238	he says he'll give me a call when his friend's got the money but that he's definitely buying before the end of the week	0
+3239	the world is running and i am still.maybe all are feeling the same,so be it.or i have to admit,i am mad.then where is the correction?or let me call this is life.and keep running with the world,may be u r also running.lets run.	0
+3240	do you still have the grinder?	0
+3241	already am squatting is the new way of walking	0
+3242	hot live fantasies call now 08707509020 just 20p per min ntt ltd, po box 1327 croydon cr9 5wb 0870..k	1
+3243	i'm at home. please call	0
+3244	sat right? okay thanks...	0
+3245	good night my dear.. sleepwell&amp;take care	0
+3246	i knew it... u slept v late yest? wake up so late...	0
+3247	83039 62735=??450 uk break accommodationvouchers terms & conditions apply. 2 claim you mustprovide your claim number which is 15541	1
+3248	you didn't have to tell me that...now i'm thinking. plus he's going to stop all your runs	0
+3249	watching telugu movie..wat abt u?	0
+3250	do you ever notice that when you're driving, anyone going slower than you is an idiot and everyone driving faster than you is a maniac?	1
+3251	ok i msg u b4 i leave my house.	0
+3252	hungry gay guys feeling hungry and up 4 it, now. call 08718730555 just 10p/min. to stop texts call 08712460324 (10p/min)	1
+3253	xxxmobilemovieclub: to use your credit, click the wap link in the next txt message or click here>> http://wap. xxxmobilemovieclub.com?n=qjkgighjjgcbl	1
+3254	at what time are you coming.	0
+3255	cos darren say ?_ considering mah so i ask ?_...	0
+3256	erm ??_ ill pick you up at about 6.45pm. that'll give enough time to get there, park and that.	0
+3257	the world's most happiest frnds never have the same characters... dey just have the best understanding of their differences...	0
+3258	and pls pls drink plenty plenty water	0
+3259	all boys made fun of me today. ok i have no problem. i just sent one message just for fun	0
+3260	wen ur lovable bcums angry wid u, dnt take it seriously.. coz being angry is d most childish n true way of showing deep affection, care n luv!.. kettoda manda... have nice day da.	0
+3261	oh mr sheffield! you wanna play that game, okay. you're the boss and i'm the nanny. you give me a raise and i'll give you one!!	0
+3262	jus finish my lunch on my way home lor... i tot u dun wan 2 stay in sch today...	0
+3263	get 3 lions england tone, reply lionm 4 mono or lionp 4 poly. 4 more go 2 www.ringtones.co.uk, the original n best. tones 3gbp network operator rates apply.	1
+3264	what i told before i tell. stupid hear after i wont tell anything to you. you dad called to my brother and spoken. not with me.	0
+3265	the xmas story is peace.. the xmas msg is love.. the xmas miracle is jesus.. hav a blessed month ahead &amp; wish u merry xmas...	0
+3266	don't forget who owns you and who's private property you are ... and be my good boy always .. *passionate kiss*	0
+3267	sorry, i'll call later	0
+3268	thanks for your ringtone order, reference number x49.your mobile will be charged 4.50. should your tone not arrive please call customer services 09065989182	1
+3269	i've told you everything will stop. just dont let her get dehydrated.	0
+3270	usually the body takes care of it buy making sure it doesnt progress. can we pls continue this talk on saturday.	0
+3271	my battery is low babe	0
+3272	hahaha..use your brain dear	0
+3273	hi good mornin.. thanku wish u d same..	0
+3274	shall i get my pouch?	0
+3275	what today-sunday..sunday is holiday..so no work..	0
+3276	sorry battery died, yeah i'm here	0
+3277	reminder: you have not downloaded the content you have already paid for. goto http://doit. mymoby. tv/ to collect your content.	1
+3278	i dun thk i'll quit yet... hmmm, can go jazz ? yogasana oso can... we can go meet em after our lessons den...	0
+3279	slept? i thinkthis time ( &lt;#&gt;  pm) is not dangerous	0
+3280	hi' test on  &lt;#&gt; rd ....	0
+3281	was it something u ate?	0
+3282	ok thats cool. its , just off either raglan rd or edward rd. behind the cricket ground. gimme ring when ur closeby see you tuesday.	0
+3283	ill call you evening ill some ideas.	0
+3284	check mail.i have mailed varma and kept copy to you regarding membership.take care.insha allah.	0
+3285	u wan 2 haf lunch i'm in da canteen now.	0
+3286	urgent!! your 4* costa del sol holiday or ??5000 await collection. call 09050090044 now toclaim. sae, tc s, pobox334, stockport, sk38xh, cost??1.50/pm, max10mins	1
+3287	should i tell my friend not to come round til like  &lt;#&gt; ish?	0
+3288	haha mayb u're rite... u know me well. da feeling of being liked by someone is gd lor. u faster go find one then all gals in our group attached liao.	0
+3289	i dnt wnt to tlk wid u	0
+3290	had the money issue weigh me down but thanks to you, i can breathe easier now. i.ll make sure you dont regret it. thanks.	0
+3291	more people are dogging in your area now. call 09090204448 and join like minded guys. why not arrange 1 yourself. there's 1 this evening. a??1.50 minapn ls278bb	1
+3292	who are you seeing?	0
+3293	hows the street where the end of library walk is?	0
+3294	she ran off with a younger man. we will make pretty babies together :)	0
+3295	hey gals.. anyone of u going down to e driving centre tmr?	0
+3296	ok lor.	0
+3297	7 wonders in my world 7th you 6th ur style 5th ur smile 4th ur personality 3rd ur nature 2nd ur sms and 1st \ur lovely friendship\"... good morning dear"	0
+3298	smile in pleasure smile in pain smile when trouble pours like rain smile when sum1 hurts u smile becoz someone still loves to see u smiling!!	0
+3299	urgent! your mobile no *********** won a ??2,000 bonus caller prize on 02/06/03! this is the 2nd attempt to reach you! call 09066362220 asap! box97n7qp, 150ppm	1
+3300	i sent you the prices and do you mean the  &lt;#&gt; g,	0
+3301	ok. i.ll do you right later.	0
+3302	ya even those cookies have jelly on them	0
+3303	you're not sure that i'm not trying to make xavier smoke because i don't want to smoke after being told i smoke too much?	0
+3304	hi darlin i hope you had a nice night i wish i had come cant wait to see you love fran ps i want dirty anal sex and i want a 10 man gang bang	0
+3305	this is the 2nd time we have tried 2 contact u. u have won the 750 pound prize. 2 claim is easy, call 08712101358 now! only 10p per min. bt-national-rate	1
+3306	aight, i'll text you when i'm back	0
+3307	i just made some payments so dont have that much. sorry. would you want it fedex or the other way.	0
+3308	will do. was exhausted on train this morning. too much wine and pie. you sleep well too	0
+3309	alex knows a guy who sells mids but he's down in south tampa and i don't think i could set it up before like 8	0
+3310	sorry i now then c ur msg... yar lor so poor thing... but only 4 one night... tmr u'll have a brand new room 2 sleep in...	0
+3311	&lt;#&gt;  w jetton ave if you forgot	0
+3312	your account has been credited with 500 free text messages. to activate, just txt the word: credit to no: 80488 t&cs www.80488.biz	1
+3313	wanna have a laugh? try chit-chat on your mobile now! logon by txting the word: chat and send it to no: 8883 cm po box 4217 london w1a 6zf 16+ 118p/msg rcvd	1
+3314	its a valentine game. . . send dis msg to all ur friends. . if 5 answers r d same then someone really loves u. . ques- which colour suits me the best?	0
+3315	for ur chance to win a ??250 wkly shopping spree txt: shop to 80878. t's&c's www.txt-2-shop.com custcare 08715705022, 1x150p/wk	1
+3316	dude ive been seeing a lotta corvettes lately	0
+3317	i cant pick the phone right now. pls send a message	0
+3318	watever relation u built up in dis world only thing which remains atlast iz lonlines with lotz n lot memories! feeling..	0
+3319	maybe i could get book out tomo then return it immediately ..? or something.	0
+3320	if i die i want u to have all my stuffs.	0
+3321	good morning princess! happy new year!	0
+3322	haha good to hear, i'm officially paid and on the market for an 8th	0
+3323	, ow u dey.i paid 60,400thousad.i told  u would call .	0
+3324	short but cute: \be a good person	0
+3325	what is important is that you prevent dehydration by giving her enough fluids	0
+3326	u meet other fren dun wan meet me ah... muz b a guy rite...	0
+3327	arngd marriage is while u r walkin unfortuntly a snake bites u. bt love marriage is dancing in frnt of d snake &amp; sayin bite me, bite me.	0
+3328	what not under standing.	0
+3329	i'm also came to room.	0
+3330	lol yep did that yesterday. already got my fireplace. now its just another icon sitting there for me.	0
+3331	ugh just got outta class	0
+3332	yes. nigh you cant aha.	0
+3333	am new 2 club & dont fink we met yet will b gr8 2 c u please leave msg 2day wiv ur area 09099726553 reply promised carlie x calls??1/minmobsmore lkpobox177hp51fl	1
+3334	have you got xmas radio times. if not i will get it now	0
+3335	sen told that he is going to join his uncle finance in cbe	0
+3336	i had askd u a question some hours before. its answer	0
+3337	i'm in a movie. call me 4 wat?	0
+3338	so li hai... me bored now da lecturer repeating last weeks stuff waste time...	0
+3339	did you show him and wot did he say or could u not c him 4 dust?	0
+3340	what happen dear tell me	0
+3341	are you this much buzy	0
+3342	u sure u can't take any sick time?	0
+3343	ok. i only ask abt e movie. u wan ktv oso?	0
+3344	nope, i'm still in the market	0
+3345	this weeks savamob member offers are now accessible. just call 08709501522 for details! savamob, pobox 139, la3 2wu. only ??1.50/week. savamob - offers mobile!	1
+3346	sir, i need velusamy sir's date of birth and company bank facilities details.	0
+3347	that is wondar full flim.	0
+3348	have you had a good day? mine was really busy are you up to much tomorrow night?	0
+3349	sexy sexy cum and text me im wet and warm and ready for some porn! u up for some fun? this msg is free recd msgs 150p inc vat 2 cancel text stop	1
+3350	mmm so yummy babe ... nice jolt to the suzy	0
+3351	dont flatter yourself... tell that man of mine two pints of carlin in ten minutes please....	0
+3352	seriously. tell her those exact words right now.	0
+3353	love has one law; make happy the person you love. in the same way friendship has one law; never make ur friend feel alone until you are alive.... gud night	0
+3354	huh means computational science... y they like dat one push here n there...	0
+3355	i cant pick the phone right now. pls send a message	0
+3356	do you like shaking your booty on the dance floor?	0
+3357	jay wants to work out first, how's 4 sound?	0
+3358	thank you princess! i want to see your nice juicy booty...	0
+3359	hmmm.still we dont have opener?	0
+3360	haven't eaten all day. i'm sitting here staring at this juicy pizza and i can't eat it. these meds are ruining my life.	0
+3361	night has ended for another day, morning has come in a special way. may you smile like the sunny rays and leaves your worries at the blue blue bay. gud mrng	0
+3362	i'd say that's a good sign but, well, you know my track record at reading women	0
+3363	university of southern california.	0
+3364	8007 25p 4 alfie moon's children in need song on ur mob. tell ur m8s. txt tone charity to 8007 for nokias or poly charity for polys :zed 08701417012 profit 2 charity	1
+3365	thanks again for your reply today. when is ur visa coming in. and r u still buying the gucci and bags. my sister things are not easy, uncle john also has his own bills so i really need to think about how to make my own money. later sha.	0
+3366	yup... i havent been there before... you want to go for the yoga? i can call up to book	0
+3367	understand. his loss is my gain :) so do you work? school?	0
+3368	i have lost 10 kilos as of today!	0
+3369	is avatar supposed to have subtoitles	0
+3370	\urgent! this is the 2nd attempt to contact u!u have won ??1000call 09071512432 b4 300603t&csbcm4235wc1n3xx.callcost150ppmmobilesvary. max??7. 50\""	1
+3371	dear i am not denying your words please	0
+3372	am on the uworld site. am i buying the qbank only or am i buying it with the self assessment also?	0
+3373	i hope your alright babe? i worry that you might have felt a bit desparate when you learned the job was a fake ? i am here waiting when you come back, my love	0
+3374	i take it we didn't have the phone callon friday. can we assume we won't have it this year now?	0
+3375	latest nokia mobile or ipod mp3 player +??400 proze guaranteed! reply with: win to 83355 now! norcorp ltd.??1,50/mtmsgrcvd18+	1
+3376	or i go home first lar ?_ wait 4 me lor.. i put down my stuff first..	0
+3377	sorry man, my stash ran dry last night and i can't pick up more until sunday	0
+3378	hiya, had a good day? have you spoken to since the weekend?	0
+3379	how much are we getting?	0
+3380	you are a big chic. common. declare	0
+3381	no we sell it all so we'll have tons if coins. then sell our coins to someone thru paypal. voila! money back in life pockets:)	0
+3382	yunny i'm walking in citylink now ?_ faster come down... me very hungry...	0
+3383	oh great. i.ll disturb him more so that we can talk.	0
+3384	that's y u haf 2 keep me busy...	0
+3385	our mobile number has won ??5000, to claim calls us back or ring the claims hot line on 09050005321.	1
+3386	u have a secret admirer. reveal who thinks u r so special. call 09065174042. to opt out reply reveal stop. 1.50 per msg recd. cust care 07821230901	1
+3387	networking job is there.	0
+3388	k..k...from tomorrow onwards started ah?	0
+3389	yes. they replied my mail. i'm going to the management office later. plus will in to bank later also.or on wednesday.	0
+3390	yeah sure i'll leave in a min	0
+3391	how come?	0
+3392	i think i've fixed it can you send a test message?	0
+3393	k i'll head out in a few mins, see you there	0
+3394	love isn't a decision, it's a feeling. if we could decide who to love, then, life would be much simpler, but then less magical	0
+3395	u r too much close to my heart. if u go away i will be shattered. plz stay with me.	0
+3396	yes! the only place in town to meet exciting adult singles is now in the uk. txt chat to 86688 now! 150p/msg.	1
+3397	and how you will do that, princess? :)	0
+3398	please da call me any mistake from my side sorry da. pls da goto doctor.	0
+3399	yeah, we can probably swing by once my roommate finishes up with his girl	0
+3400	u attend ur driving lesson how many times a wk n which day?	0
+3401	can you open the door?	0
+3402	hi frnd, which is best way to avoid missunderstding wit our beloved one's?	0
+3403	future is not what we planned for tomorrow.....! it is the result of what we do today...! do the best in present... enjoy the future.	0
+3404	congratulations ur awarded either a yrs supply of cds from virgin records or a mystery gift guaranteed call 09061104283 ts&cs www.smsco.net ??1.50pm approx 3mins	1
+3405	nope but i'm going home now then go pump petrol lor... like going 2 rain soon...	0
+3406	blank is blank. but wat is blank? lol	0
+3407	\gran onlyfound out afew days ago.cusoon honi\""	0
+3408	probably not, still going over some stuff here	0
+3409	stupid.its not possible	0
+3410	yes. it's all innocent fun. o:-)	0
+3411	thank you. i like you as well...	0
+3412	r we still meeting 4 dinner tonight?	0
+3413	lol ... i really need to remember to eat when i'm drinking but i do appreciate you keeping me company that night babe *smiles*	0
+3414	great to hear you are settling well. so what's happenin wit ola?	0
+3415	no one interested. may be some business plan.	0
+3416	geeeee ... i love you so much i can barely stand it	0
+3417	meet after lunch la...	0
+3418	was gr8 to see that message. so when r u leaving? congrats dear. what school and wat r ur plans.	0
+3419	was actually sleeping and still might when u call back. so a text is gr8. you rock sis. will send u a text wen i wake.	0
+3420	sian... aft meeting supervisor got work 2 do liao... u working now?	0
+3421	s:-)kallis wont play in first two odi:-)	0
+3422	life has never been this much fun and great until you came in. you made it truly special for me. i won't forget you! enjoy @ one gbp/sms	1
+3423	pass dis to all ur contacts n see wat u get! red;i'm in luv wid u. blue;u put a smile on my face. purple;u r realy hot. pink;u r so swt. orange;i thnk i lyk u. green;i realy wana go out wid u. yelow;i wnt u bck. black;i'm jealous of u. brown;i miss you nw plz giv me one color	0
+3424	thts wat wright brother did to fly..	0
+3425	mm you ask him to come its enough :-)	0
+3426	maybe you should find something else to do instead???	0
+3427	as usual..iam fine, happy &amp; doing well..:)	0
+3428	you have won! as a valued vodafone customer our computer has picked you to win a ??150 prize. to collect is easy. just call 09061743386	1
+3429	sorry i din lock my keypad.	0
+3430	k, wen ur free come to my home and also tel vikky i hav sent mail to him also.. better come evening il be free today aftr 6pm..:-)	0
+3431	babe, i'm back ... come back to me ...	0
+3432	am i that much dirty fellow?	0
+3433	will purchase d stuff today and mail to you. do you have a po box number?	0
+3434	quite ok but a bit ex... u better go eat smth now else i'll feel guilty...	0
+3435	pls help me tell sura that i'm expecting a battery from hont. and that if should pls send me a message about how to download movies. thanks	0
+3436	i have gone into get info bt dont know what to do	0
+3437	hey chief, can you give me a bell when you get this. need to talk to you about this royal visit on the 1st june.	0
+3438	somebody should go to andros and steal ice	0
+3439	may i call you later pls	0
+3440	fuck cedar key and fuck her (come over anyway tho)	0
+3441	what's nannys address?	0
+3442	congrats! 2 mobile 3g videophones r yours. call 09063458130 now! videochat wid your mates, play java games, dload polyph music, noline rentl.	1
+3443	huh so slow i tot u reach long ago liao... u 2 more days only i 4 more leh...	0
+3444	u horrible gal... u knew dat i was going out wif him yest n u still come n ask me...	0
+3445	spoons it is then okay?	0
+3446	hey we can go jazz power yoga hip hop kb and yogasana	0
+3447	al he does is moan at me if n e thin goes wrong its my fault&al de arguments r my fault&fed up of him of himso y bother? hav 2go, thanx.xx	0
+3448	lol! nah wasn't too bad thanks. its good to b home but its been quite a reality check. hows ur day been? did u do anything with website?	0
+3449	the length is e same but e top shorter n i got a fringe now. i thk i'm not going liao. too lazy. dun wan 2 distract u also.	0
+3450	howz that persons story	0
+3451	1's reach home call me.	0
+3452	no current and food here. i am alone also	0
+3453	ok	0
+3454	yup i thk cine is better cos no need 2 go down 2 plaza mah.	0
+3455	not for possession, especially not first offense	0
+3456	and by when you're done i mean now	0
+3457	yetunde, i'm sorry but moji and i seem too busy to be able to go shopping. can you just please find some other way to get what you wanted us to get. please forgive me. you can reply free via yahoo messenger.	0
+3458	sorry . i will be able to get to you. see you in the morning.	0
+3459	my computer just fried the only essential part we don't keep spares of because my fucking idiot roommates looovvve leaving the thing running on full  &lt;#&gt; /7	0
+3460	u're welcome... caught u using broken english again...	0
+3461	my drive can only be read. i need to write	0
+3462	aiyar sorry lor forgot 2 tell u...	0
+3463	sorry, i'll call later	0
+3464	sez, hows u & de arab boy? hope u r all good give my love 2 evry1 love ya eshxxxxxxxxxxx	0
+3465	sounds good, keep me posted	0
+3466	call germany for only 1 pence per minute! call from a fixed line via access number 0844 861 85 85. no prepayment. direct access! www.telediscount.co.uk	1
+3467	dear u've been invited to xchat. this is our final attempt to contact u! txt chat to 86688 150p/msgrcvdhg/suite342/2lands/row/w1j6hl ldn 18 yrs	1
+3468	you have been selected to stay in 1 of 250 top british hotels - for nothing! holiday worth ??350! to claim, call london 02072069400. bx 526, sw73ss	1
+3469	k:)k:)good:)study well.	0
+3470	otherwise had part time job na-tuition..	0
+3471	it wont b until 2.15 as trying 2 sort house out, is that ok?	0
+3472	\cha quiteamuzing that??scool babe	0
+3473	that's one of the issues but california is okay. no snow so its manageable	0
+3474	okie.. thanx..	0
+3475	appt is at &lt;time&gt; am. not my fault u don't listen. i told u twice	0
+3476	buzzzz! *grins* did i buzz your ass? buzz your chest ? buzz your cock ? where do you keep your phone ? is the vibrator on ? did you feel it shake ?	0
+3477	great comedy..cant stop laughing da:)	0
+3478	no that just means you have a fat head	0
+3479	we're on the opposite side from where we dropped you off	0
+3480	hi ....my engagement has been fixd on  &lt;#&gt; th of next month. i know its really shocking bt....hmm njan vilikkam....t ws al of a sudn;-(.	0
+3481	hey u still at the gym?	0
+3482	yup n her fren lor. i'm meeting my fren at 730.	0
+3483	yes... i trust u to buy new stuff asap so i can try it out	0
+3484	by the way, i've put a skip right outside the front of the house so you can see which house it is. just pull up before it.	0
+3485	honeybee said: *i'm d sweetest in d world* god laughed &amp; said: *wait,u havnt met d person reading this msg* moral: even god can crack jokes! gm+gn+ge+gn:)	0
+3486	88066 from 88066 lost 3pound help	1
+3487	no. i.ll meet you in the library	0
+3488	anything lor. juz both of us lor.	0
+3489	me, i dont know again oh	0
+3490	in meeting da. i will call you	0
+3491	if you aren't here in the next  &lt;#&gt;  hours imma flip my shit	0
+3492	i went to project centre	0
+3493	this is wishing you a great day. moji told me about your offer and as always i was speechless. you offer so easily to go to great lengths on my behalf and its stunning. my exam is next friday. after that i will keep in touch more. sorry.	0
+3494	what you doing?how are you?	0
+3495	i wish things were different. i wonder when i will be able to show you how much i value you. pls continue the brisk walks no drugs without askin me please and find things to laugh about. i love you dearly.	0
+3496	lol no. u can trust me.	0
+3497	sorry vikky, i'm watching olave mandara movie kano in trishul theatre wit my frnds..	0
+3498	dont kick coco when he's down	0
+3499	how was txting and driving	0
+3500	am i the only one who doesn't stalk profiles?	0
+3501	good morning plz call me sir	0
+3502	detroit. the home of snow. enjoy it.	0
+3503	tell them u have a headache and just want to use 1 hour of sick time.	0
+3504	it could work, we'll reach a consensus at the next meeting	0
+3505	in fact when do you leave? i think addie goes back to school tues or wed	0
+3506	good afternoon starshine! how's my boytoy? does he crave me yet? ache to fuck me ? *sips cappuccino* i miss you babe *teasing kiss*	0
+3507	when you just put in the + sign, choose my number and the pin will show. right?	0
+3508	great new offer - double mins & double txt on best orange tariffs and get latest camera phones 4 free! call mobileupd8 free on 08000839402 now! or 2stoptxt t&cs	1
+3509	hmm...bad news...hype park plaza $700 studio taken...only left 2 bedrm-$900...	0
+3510	ditto. and you won't have to worry about me saying anything to you anymore. like i said last night, you do whatever you want and i'll do the same. peace.	0
+3511	that day ?_ say ?_ cut ur hair at paragon, is it called hair sense? do ?_ noe how much is a hair cut?	0
+3512	en chikku nange bakra msg kalstiya..then had tea/coffee?	0
+3513	somewhere out there beneath the pale moon light someone think in of u some where out there where dreams come true... goodnite &amp; sweet dreams	0
+3514	s:)no competition for him.	0
+3515	s:)8 min to go for lunch:)	0
+3516	g wants to know where the fuck you are	0
+3517	i've reached home finally...	0
+3518	yep, by the pretty sculpture	0
+3519	wishing you and your family merry \x\" mas and happy new year in advance.."	0
+3520	meanwhile in the shit suite: xavier decided to give us  &lt;#&gt;  seconds of warning that samantha was coming over and is playing jay's guitar to impress her or some shit. also i don't think doug realizes i don't live here anymore	0
+3521	minimum walk is 3miles a day.	0
+3522	big brother???s really scraped the barrel with this shower of social misfits	0
+3523	win: we have a winner! mr. t. foley won an ipod! more exciting prizes soon, so keep an eye on ur mobile or visit www.win-82050.co.uk	1
+3524	how are you holding up?	0
+3525	piss is talking is someone that realise u that point this at is it.(now read it backwards)	0
+3526	what are you doing in langport? sorry, but i'll probably be in bed by 9pm. it sucks being ill at xmas! when do you and go2sri lanka?	0
+3527	ew are you one of them?	0
+3528	what * u wearing?	0
+3529	did u fix the teeth?if not do it asap.ok take care.	0
+3530	mm i am on the way to railway	0
+3531	dear got bus directly to calicut	0
+3532	we know taj mahal as symbol of love. but the other lesser known facts 1. mumtaz was shahjahan's 4th wife, out of his 7 wifes. 2. shahjahan killed mumtaz's husband to marry her. 3. mumtaz died in her  &lt;#&gt; th delivery. 4. he then married mumtaz's sister. question arises where the hell is the love?:-| -the great hari-	0
+3533	just checked out, heading out to drop off my stuff now	0
+3534	nice.nice.how is it working?	0
+3535	i calls you later. afternoon onwords mtnl service get problem in south mumbai. i can hear you but you cann't listen me.	0
+3536	do you want a new video handset? 750 anytime any network mins? half price line rental? camcorder? reply or call 08000930705 for delivery tomorrow	1
+3537	aiyah ok wat as long as got improve can already wat...	0
+3538	how did you find out in a way that didn't include all of these details	0
+3539	later i guess. i needa do mcat study too.	0
+3540	aight, lemme know what's up	0
+3541	bill, as in: are there any letters for me. i???m expecting one from orange that isn???t a bill but may still say orange on it.	0
+3542	ok i will tell her to stay out. yeah its been tough but we are optimistic things will improve this month.	0
+3543	urgent! we are trying to contact you. last weekends draw shows that you have won a ??900 prize guaranteed. call 09061701939. claim code s89. valid 12hrs only	1
+3544	i'm in town now so i'll jus take mrt down later.	0
+3545	i will come to ur home now	0
+3546	yes! how is a pretty lady like you single?	0
+3547	s but not able to sleep.	0
+3548	i???ll have a look at the frying pan in case it???s cheap or a book perhaps. no that???s silly a frying pan isn???t likely to be a book	0
+3549	is ur paper in e morn or aft tmr?	0
+3550	what do u reckon as need 2 arrange transport if u can't do it, thanks	0
+3551	no other valentines huh? the proof is on your fb page. ugh i'm so glad i really didn't watch your rupaul show you tool!	0
+3552	hm good morning, headache anyone? :-)	0
+3553	doesn't g have class early tomorrow and thus shouldn't be trying to smoke at  &lt;#&gt;	0
+3554	wat makes some people dearer is not just de happiness dat u feel when u meet them but de pain u feel when u miss dem!!!	0
+3555	hhahhaahahah rofl wtf nig was leonardo in your room or something	0
+3556	you are a winner you have been specially selected to receive ??1000 cash or a ??2000 award. speak to a live operator to claim call 087147123779am-7pm. cost 10p	1
+3557	yetunde i'm in class can you not run water on it to make it ok. pls now.	0
+3558	\thinking of u ;) x\""	0
+3559	ok lor. i'm in town now lei.	0
+3560	no drama pls.i have had enough from you and family while i am struggling in the hot sun in a strange place.no reason why there should be an ego of not going 'if not invited' when actually its necessity to go.wait for very serious reppurcussions.	0
+3561	he's really into skateboarding now despite the fact that he gets thrown off of it and winds up with bandages and shit all over his arms every five minutes	0
+3562	say this slowly.? god,i love you &amp; i need you,clean my heart with your blood.send this to ten special people &amp; u c miracle tomorrow, do it,pls,pls do it...	0
+3563	huh? 6 also cannot? then only how many mistakes?	0
+3564	\wen u miss someone  why to miss them  just keep-in-touch\" gdeve.."	0
+3565	sorry about earlier. putting out fires.are you around to talk after 9? or do you actually have a life, lol!	0
+3566	we will meet soon princess! ttyl!	0
+3567	ill b down soon	0
+3568	want to funk up ur fone with a weekly new tone reply tones2u 2 this text. www.ringtones.co.uk, the original n best. tones 3gbp network operator rates apply	1
+3569	haha, my friend tyler literally just asked if you could get him a dubsack	0
+3570	i dled 3d its very imp	0
+3571	are you ok. what happen to behave like this	0
+3572	damn, can you make it tonight or do you want to just wait til tomorrow	0
+3573	+123 congratulations - in this week's competition draw u have won the ??1450 prize to claim just call 09050002311 b4280703. t&cs/stop sms 08718727868. over 18 only 150ppm	1
+3574	then mum's repent how?	0
+3575	ever green quote ever told by jerry in cartoon \a person who irritates u always is the one who loves u vry much but fails to express it...!..!! :-) :-) gud nyt"	0
+3576	no da if you run that it activate the full version da.	0
+3577	haiyoh... maybe your hamster was jealous of million	0
+3578	senthil group company apnt 5pm.	0
+3579	hottest pics straight to your phone!! see me getting wet and wanting, just for you xx text pics to 89555 now! txt costs 150p textoperator g696ga 18 xxx	1
+3580	there generally isn't one. it's an uncountable noun - u in the dictionary. pieces of research?	0
+3581	today iz yellow rose day. if u love my frndship give me 1 misscall &amp; send this to ur frndz &amp; see how many miss calls u get. if u get 6missed u marry ur lover.	0
+3582	ya:)going for restaurant..	0
+3583	oh did you charge camera	0
+3584	my uncles in atlanta. wish you guys a great semester.	0
+3585	how is my boy? no sweet words left for me this morning ... *sighs* ... how goes you day, my love ? did you start your studying?	0
+3586	yep, the great loxahatchee xmas tree burning of  &lt;#&gt;  starts in an hour	0
+3587	:-) yeah! lol. luckily i didn't have a starring role like you!	0
+3588	how do friends help us in problems? they give the most stupid suggestion that lands us into another problem and helps us forgt the previous problem	0
+3589	abeg, make profit. but its a start. are you using it to get sponsors for the next event?	0
+3590	hi im having the most relaxing time ever! we have to get up at 7am every day! was the party good the other night? i get home tomorrow at 5ish.	0
+3591	camera - you are awarded a sipix digital camera! call 09061221066 fromm landline. delivery within 28 days	1
+3592	i've got it down to a tea. not sure which flavour	0
+3593	yeah like if it goes like it did with my friends imma flip my shit in like half an hour	0
+3594	we don call like  &lt;#&gt;  times oh. no give us hypertension oh.	0
+3595	armand says get your ass over to epsilon	0
+3596	* will have two more cartons off u and is very pleased with shelves	0
+3597	oh really?? did you make it on air? what's your talent?	0
+3598	i love you both too :-)	0
+3599	you're right i have now that i think about it	0
+3600	i dont have any of your file in my bag..i was in work when you called me.i 'll tell you if i find anything in my room.	0
+3601	thanx but my birthday is over already.	0
+3602	do u noe wat time e place dat sells 4d closes?	0
+3603	ok... ur typical reply...	0
+3604	no i'm not gonna be able to. || too late notice. || i'll be home in a few weeks anyway. || what are the plans	0
+3605	urgent! your mobile number has been awarded with a ??2000 prize guaranteed. call 09061790121 from land line. claim 3030. valid 12hrs only 150ppm	1
+3606	monthly password for wap. mobsi.com is 391784. use your wap phone not pc.	1
+3607	msgs r not time pass.they silently say that i am thinking of u right now and also making u think of me at least 4 a moment. gd nt.swt drms @shesil	0
+3608	hi the way i was with u 2day, is the normal way&this is the real me. ur unique&i hope i know u 4 the rest of mylife. hope u find wot was lost.	0
+3609	please ask mummy to call father	0
+3610	hey, i missed you tm of last night as my phone was on the charge ... *smiles* ... i am meeting a friend shortly	0
+3611	those were my exact intentions	0
+3612	you are a winner u have been specially selected 2 receive ??1000 cash or a 4* holiday (flights inc) speak to a live operator 2 claim 0871277810810	1
+3613	only once then after ill obey all yours.	0
+3614	do you hide anythiing or keeping distance from me	0
+3615	until 545 lor... ya, can go 4 dinner together...	0
+3616	the guy (kadeem) hasn't been selling since the break, i know one other guy but he's paranoid as fuck and doesn't like selling without me there and i can't be up there til late tonight	0
+3617	hey you still want to go for yogasana? coz if we end at cine then can go bathe and hav the steam bath	0
+3618	send a logo 2 ur lover - 2 names joined by a heart. txt love name1 name2 mobno eg love adam eve 07123456789 to 87077 yahoo! pobox36504w45wq txtno 4 no ads 150p	1
+3619	thank u!	0
+3620	for you information, ikea is spelled with all caps. that is not yelling. when you thought i had left you, you were sitting on the bed among the mess when i came in. i said we were going after you got home from class. please don't try and bullshit me. it makes me want to listen to you less.	0
+3621	i wnt to buy a bmw car urgently..its vry urgent.but hv a shortage of  &lt;#&gt; lacs.there is no source to arng dis amt. &lt;#&gt; lacs..thats my prob	0
+3622	smsservices. for yourinclusive text credits, pls goto www.comuk.net login= 3qxj9 unsubscribe with stop, no extra charge. help 08702840625.comuk. 220-cm2 9ae	1
+3623	just sleeping..and surfing	0
+3624	babe, i need your advice	0
+3625	u have a secret admirer who is looking 2 make contact with u-find out who they r*reveal who thinks ur so special-call on 09058094599	1
+3626	lol ok your forgiven :)	0
+3627	where you. what happen	0
+3628	not sure i have the stomach for it ...	0
+3629	yeah whatever lol	0
+3630	yeah just open chat and click friend lists. then make the list. easy as pie	0
+3631	that's good. lets thank god. please complete the drug. have lots of water. and have a beautiful day.	0
+3632	do u konw waht is rael friendship im gving yuo an exmpel: jsut ese tihs msg.. evrey splleing of tihs msg is wrnog.. bt sitll yuo can raed it wihtuot ayn mitsake.. goodnight &amp; have a nice sleep..sweet dreams..	0
+3633	yup i thk they r e teacher said that will make my face look longer. darren ask me not 2 cut too short.	0
+3634	oic... i saw him too but i tot he din c me... i found a group liao...	0
+3635	when u love someone dont make them to love u as much as u do. but love them so much that they dont want to be loved by anyone except you... gud nit.	0
+3636	are you this much buzy	0
+3637	free message activate your 500 free text messages by replying to this message with the word free for terms & conditions, visit www.07781482378.com	1
+3638	u are subscribed to the best mobile content service in the uk for ??3 per 10 days until you send stop to 82324. helpline 08706091795	1
+3639	bought one ringtone and now getting texts costing 3 pound offering more tones etc	1
+3640	dont pick up d call when something important is there to tell. hrishi	0
+3641	my mobile number.pls sms ur mail id.convey regards to achan,amma.rakhesh.qatar	0
+3642	she's good. how are you. where r u working now	0
+3643	at home watching tv lor.	0
+3644	we are at grandmas. oh dear, u still ill? i felt shit this morning but i think i am just hungover! another night then. we leave on sat.	0
+3645	i am in office:)whats the matter..msg me now.i will call you at break:).	0
+3646	onum ela pa. normal than.	0
+3647	at home by the way	0
+3648	hope ur head doesn't hurt 2 much ! am ploughing my way through a pile of ironing ! staying in with a chinky tonight come round if you like.	0
+3649	what does the dance river do?	0
+3650	hope you are not scared!	0
+3651	;-) ok. i feel like john lennon.	0
+3652	just getting back home	0
+3653	urgent! please call 09061743811 from landline. your abta complimentary 4* tenerife holiday or ??5000 cash await collection sae t&cs box 326 cw25wx 150ppm	1
+3654	try neva mate!!	0
+3655	i'm eatin now lor, but goin back to work soon... e mountain deer show huh... i watch b4 liao, very nice...	0
+3656	why tired what special there you had	0
+3657	poyyarikatur,kolathupalayam,unjalur post,erode dis, &lt;#&gt; .	0
+3658	i thk u dun haf 2 hint in e forum already lor... cos i told ron n darren is going 2 tell shuhui.	0
+3659	you are a great role model. you are giving so much and i really wish each day for a miracle but god as a reason for everything and i must say i wish i knew why but i dont. i've looked up to you since i was young and i still do. have a great day.	0
+3660	where is it. is there any opening for mca.	0
+3661	sorry, i'll call later	0
+3662	your gonna be the death if me. i'm gonna leave a note that says its all robs fault. avenge me.	0
+3663	o we cant see if we can join denis and mina? or does denis want alone time	0
+3664	okie...	0
+3665	upgrdcentre orange customer, you may now claim your free camera phone upgrade for your loyalty. call now on 0207 153 9153. offer ends 26th july. t&c's apply. opt-out available	1
+3666	i wake up long ago already... dunno, what other thing?	0
+3667	yes ammae....life takes lot of turns you can only sit and try to hold the steering...	0
+3668	sorry, i'll call later	0
+3669	aiyah u did ok already lar. e nydc at wheellock?	0
+3670	what to think no one saying clearly. ok leave no need to ask her. i will go if she come or not	0
+3671	i asked you to call him now ok	0
+3672	dear matthew please call 09063440451 from a landline, your complimentary 4*lux tenerife holiday or ??1000 cash await collection. ppm150 sae t&cs box334 sk38xh.	1
+3673	s.i think he is waste for rr..	0
+3674	we got a divorce. lol. she.s here	0
+3675	hello! just got here, st andrews-boy its a long way! its cold. i will keep you posted	0
+3676	ok. so april. cant wait	0
+3677	how long does applebees fucking take	0
+3678	?? dun wan to watch infernal affair?	0
+3679	okay same with me. well thanks for the clarification	0
+3680	how have your little darlings been so far this week? need a coffee run tomo?can't believe it's that time of week already ??_	0
+3681	dont search love, let love find u. thats why its called falling in love, bcoz u dont force yourself, u just fall and u know there is smeone to hold u... bslvyl	0
+3682	pls send me the correct name da.	0
+3683	it certainly puts things into perspective when something like this happens	0
+3684	we tried to contact you re your response to our offer of a new nokia fone and camcorder hit reply or call 08000930705 for delivery	1
+3685	private! your 2003 account statement for 07753741225 shows 800 un-redeemed s. i. m. points. call 08715203677 identifier code: 42478 expires 24/10/04	1
+3686	hows that watch resizing	0
+3687	win the newest ???harry potter and the order of the phoenix (book 5) reply harry, answer 5 questions - chance to be the first among readers!	1
+3688	i guess you could be as good an excuse as any, lol.	0
+3689	sunshine hols. to claim ur med holiday send a stamped self address envelope to drinks on us uk, po box 113, bray, wicklow, eire. quiz starts saturday! unsub stop	1
+3690	aft i finish my lunch then i go str down lor. ard 3 smth lor. u finish ur lunch already?	0
+3691	hello darlin ive finished college now so txt me when u finish if u can love kate xxx	0
+3692	easy ah?sen got selected means its good..	0
+3693	\hi darlin did youphone me? im athome if youwanna chat.\""	0
+3694	omg how did u know what i ate?	0
+3695	ambrith..madurai..met u in arun dha marrge..remembr?	0
+3696	haha, just what i was thinkin	0
+3697	its so common hearin how r u? wat r u doing? how was ur day? so let me ask u something different. did u smile today? if not, do it now.... gud evng.	0
+3698	what is your account number?	0
+3699	sac needs to carry on:)	0
+3700	yun ah.now ?_ wkg where?btw if ?_ go nus sc. ?? wana specialise in wad?	0
+3701	its worse if if uses half way then stops. its better for him to complete it.	0
+3702	the lay man! just to let you know you are missed and thought off. do have a great day. and if you can send me bimbo and ugo's numbers, ill appreciate. safe	0
+3703	no message..no responce..what happend?	0
+3704	i'm very happy for you babe ! woo hoo party on dude!	0
+3705	when i was born, god said, \oh no! another idiot\". when you were born  \"oh no! competition\". who knew  one day these two will become freinds forever!"	0
+3706	im done. just studyn in library	0
+3707	yes da. any plm at ur office	0
+3708	you'll never believe this but i have actually got off at taunton. wow	0
+3709	sorry, went to bed early, nightnight	0
+3710	wish i were with you now!	0
+3711	wat u doing there?	0
+3712	yeah that's what i thought, lemme know if anything's goin on later	0
+3713	oh oh... den muz change plan liao... go back have to yan jiu again...	0
+3714	hi this is yijue... it's regarding the 3230 textbook it's intro to algorithms second edition... i'm selling it for $50...	0
+3715	okie...	0
+3716	2mro i am not coming to gym machan. goodnight.	0
+3717	nvm... i'm going to wear my sport shoes anyway... i'm going to be late leh.	0
+3718	wait, do you know if wesleys in town? i bet she does hella drugs!	0
+3719	oh for fuck's sake she's in like tallahassee	0
+3720	private! your 2004 account statement for 07742676969 shows 786 unredeemed bonus points. to claim call 08719180248 identifier code: 45239 expires	1
+3721	u have a secret admirer who is looking 2 make contact with u-find out who they r*reveal who thinks ur so special-call on 09058094599	1
+3722	thought we could go out for dinner. i'll treat you! seem ok?	0
+3723	you call him now ok i said call him	0
+3724	as a sim subscriber, you are selected to receive a bonus! get it delivered to your door, txt the word ok to no: 88600 to claim. 150p/msg, exp. 30apr	1
+3725	i can. but it will tell quite long, cos i haven't finish my film yet...	0
+3726	genius what's up. how your brother. pls send his number to my skype.	0
+3727	i'm gonna be home soon and i don't want to talk about this stuff anymore tonight, k? i've cried enough today.	0
+3728	your account has been refilled successfully by inr  &lt;decimal&gt; . your keralacircle prepaid account balance is rs  &lt;decimal&gt; . your transaction id is kr &lt;#&gt; .	0
+3729	sms auction you have won a nokia 7250i. this is what you get when you win our free auction. to take part send nokia to 86021 now. hg/suite342/2lands row/w1jhl 16+	1
+3730	it's ok lar. u sleep early too... nite...	0
+3731	ok.ok ok..then..whats ur todays plan	0
+3732	had your mobile 11 months or more? u r entitled to update to the latest colour mobiles with camera for free! call the mobile update co free on 08002986030	1
+3733	hi dear we saw dear. we both are happy. where you my battery is low	0
+3734	the whole car appreciated the last two! dad and are having a map reading semi argument but apart from that things are going ok. p.	0
+3735	today my system sh get ready.all is well and i am also in the deep well	0
+3736	how long has it been since you screamed, princess?	0
+3737	near kalainar tv office.thenampet	0
+3738	aight no rush, i'll ask jay	0
+3739	as a valued customer, i am pleased to advise you that following recent review of your mob no. you are awarded with a ??1500 bonus prize, call 09066364589	1
+3740	wait 2 min..stand at bus stop	0
+3741	thanks for this hope you had a good day today	0
+3742	for the first time in the history 'need' 'comfort' and 'luxury' are sold at same price in india..!! onion-rs. &lt;#&gt;  petrol-rs. &lt;#&gt;  beer-rs. &lt;#&gt;  shesil  &lt;#&gt;	0
+3743	ok lor.	0
+3744	sad story of a man - last week was my b'day. my wife did'nt wish me. my parents forgot n so did my kids . i went to work. even my colleagues did not wish. as i entered my cabin my pa said, '' happy b'day boss !!''. i felt special. she askd me 4 lunch. after lunch she invited me to her apartment. we went there. she said,'' do u mind if i go into the bedroom for a minute ? '' ''ok'', i sed in a sexy mood. she came out 5 minuts latr wid a cake...n my wife, my parents, my kidz, my friends n my colleagues. all screaming.. surprise !! and i was waiting on the sofa.. ... ..... ' naked...!	0
+3745	you also didnt get na hi hi hi hi hi	0
+3746	from someone not to smoke when every time i've smoked in the last two weeks is because of you calling or texting me that you wanted to smoke	0
+3747	that sucks. so what do you got planned for your yo valentine? i am your yo valentine aren't i?	0
+3748	yup i'm still having coffee wif my frens... my fren drove she'll give me a lift...	0
+3749	pls what's the full name of joke's school cos fees in university of florida seem to actually be  &lt;#&gt; k. pls holla back	0
+3750	hai priya are you right. what doctor said pa. where are you.	0
+3751	i am in tirupur.  call you da.	0
+3752	?? called dad oredi...	0
+3753	okay. i've seen it. so i should pick it on friday?	0
+3754	small problem in auction:)punj now asking tiwary	0
+3755	you've always been the brainy one.	0
+3756	i'm still looking for a car to buy. and have not gone 4the driving test yet.	0
+3757	i tot u reach liao. he said t-shirt.	0
+3758	i havent add ?_ yet right..	0
+3759	we not watching movie already. xy wants 2 shop so i'm shopping w her now.	0
+3760	shuhui say change 2 suntec steamboat? u noe where? where r u now?	0
+3761	sorry, i'll call you  later. i am in meeting sir.	0
+3762	thats cool. how was your day?	0
+3763	are you in town? this is v. important	0
+3764	you are awarded a sipix digital camera! call 09061221061 from landline. delivery within 28days. t cs box177. m221bp. 2yr warranty. 150ppm. 16 . p p??3.99	1
+3765	u r subscribed 2 textcomp 250 wkly comp. 1st wk?s free question follows, subsequent wks charged@150p/msg.2 unsubscribe txt stop 2 84128,custcare 08712405020	1
+3766	i can't make it tonight	0
+3767	natalja (25/f) is inviting you to be her friend. reply yes-440 or no-440 see her: www.sms.ac/u/nat27081980 stop? send stop frnd to 62468	1
+3768	yup... hey then one day on fri we can ask miwa and jiayin take leave go karaoke	0
+3769	i've told him that i've returned it. that should i re order it.	0
+3770	okey dokey swashbuckling stuff what oh.	0
+3771	aight sorry i take ten years to shower. what's the plan?	0
+3772	good! no, don???t need any receipts???well done! (??_) yes, please tell . what???s her number, i could ring her	0
+3773	oh gei. that happend to me in tron. maybe ill dl it in 3d when its out	0
+3774	then ?_ ask dad to pick ?_ up lar... ?? wan 2 stay until 6 meh...	0
+3775	its ok my arm is feeling weak cuz i got a shot so we can go another time	0
+3776	me not waking up until 4 in the afternoon, sup	0
+3777	\me 2 babe i feel the same lets just 4get about it+both try +cheer up+not fit soo muchxxlove u locaxx\""	0
+3778	1apple/day=no doctor. 1tulsi leaf/day=no cancer. 1lemon/day=no fat. 1cup milk/day=no bone problms 3 litres watr/day=no diseases snd ths 2 whom u care..:-)	0
+3779	haha get used to driving to usf man, i know a lot of stoners	0
+3780	ok no prob. take ur time.	0
+3781	aight, let me know when you're gonna be around usf	0
+3782	bring it if you got it	0
+3783	i'm done oredi...	0
+3784	still i have not checked it da. . .	0
+3785	if i said anything wrong sorry de:-)	0
+3786	well, i have to leave for my class babe ... you never came back to me ... :-( ... hope you have a nice sleep, my love	0
+3787	guess who spent all last night phasing in and out of the fourth dimension	0
+3788	how many licks does it take to get to the center of a tootsie pop?	0
+3789	open rebtel with firefox. when it loads just put plus sign in the user name place, and it will show you two numbers. the lower number is my number. once you pick that number the pin will display okay!	0
+3790	your pussy is perfect!	0
+3791	lol i would but despite these cramps i like being a girl.	0
+3792	loans for any purpose even if you have bad credit! tenants welcome. call noworriesloans.com on 08717111821	1
+3793	sorry, i'll call later	0
+3794	did u got that persons story	0
+3795	babe ? i lost you ... :-(	0
+3796	ok... the theory test? when are ?_ going to book? i think it's on 21 may. coz thought wanna go out with jiayin. but she isnt free	0
+3797	spoke with uncle john today. he strongly feels that you need to sacrifice to keep me here. he's going to call you. when he does, i beg you to just listen. dont make any promises or make it clear things are not easy. and i need you to please let us work things out. as long as i keep expecting help, my creativity will be stifled so pls just keep him happy, no promises on your part.	0
+3798	promotion number: 8714714 - ur awarded a city break and could win a ??200 summer shopping spree every wk. txt store to 88039 . skilgme. tscs087147403231winawk!age16 ??1.50perwksub	1
+3799	can u look 4 me in da lib i got stuff havent finish yet.	0
+3800	hello madam how are you ?	0
+3801	whenevr ur sad, whenevr ur gray, remembr im here 2 listn 2 watevr u wanna say, jus walk wid me a little while,&amp; i promise i'll bring back ur smile.:-)	0
+3802	how do you plan to manage that	0
+3803	the guy did some bitching but i acted like i'd be interested in buying something else next week and he gave it to us for free	0
+3804	\alrite hunny!wot u up 2 2nite? didnt end up goin down town jus da pub instead! jus chillin at da mo in me bedroom!love jen xxx.\""	0
+3805	ya just telling abt tht incident..	0
+3806	no but the bluray player can	0
+3807	hello peach! my cake tasts lush!	0
+3808	valentines day special! win over ??1000 in our quiz and take your partner on the trip of a lifetime! send go to 83600 now. 150p/msg rcvd. custcare:08718720201.	1
+3809	awesome question with a cute answer: someone asked a boy \how is ur life?\" . . he smiled &amp; answered: . . \"she is fine!\" gudnite"	0
+3810	sorry, i'll call later	0
+3811	studying. but i.ll be free next weekend.	0
+3812	right it wasnt you who phoned it was someone with a number like yours!	0
+3813	want explicit sex in 30 secs? ring 02073162414 now! costs 20p/min	1
+3814	okey doke. i'm at home, but not dressed cos laying around ill! speak to you later bout times and stuff.	0
+3815	dont think you need yellow card for uk travel. ask someone that has gone before. if you do its just  &lt;#&gt; bucks	0
+3816	also sir, i sent you an email about how to log into the usc payment portal. i.ll send you another message that should explain how things are back home. have a great weekend.	0
+3817	dear, will call tmorrow.pls accomodate.	0
+3818	my no. in luton 0125698789 ring me if ur around! h*	0
+3819	joy's father is john. then john is the ____ of joy's father. if u ans ths you hav  &lt;#&gt;  iq. tis s ias question try to answer.	0
+3820	can you let me know details of fri when u find out cos i'm not in tom or fri. mentionned chinese. thanks	0
+3821	hello! good week? fancy a drink or something later?	0
+3822	helloooo... wake up..! \sweet\" \"morning\" \"welcomes\" \"you\" \"enjoy\" \"this day\" \"with full of joy\".. \"gud mrng\"."	0
+3823	**free message**thanks for using the auction subscription service. 18 . 150p/msgrcvd 2 skip an auction txt out. 2 unsubscribe txt stop customercare 08718726270	1
+3824	<forwarded from 448712404000>please call 08712404000 immediately as there is an urgent message waiting for you.	1
+3825	how are u? i have missed u! i havent been up 2 much a bit bored with the holiday want 2 go bak 2 college! sad isnt it?xx	0
+3826	i can't believe how attached i am to seeing you every day. i know you will do the best you can to get to me babe. i will go to teach my class at your midnight	0
+3827	ya i knw u vl giv..its ok thanks kano..anyway enjoy wit ur family wit 1st salary..:-);-)	0
+3828	yes.. now only saw your message..	0
+3829	u say leh... of course nothing happen lar. not say v romantic jus a bit only lor. i thk e nite scenery not so nice leh.	0
+3830	then just eat a shit and wait for ur monkey face bitch.......... u asshole..................	0
+3831	feb  &lt;#&gt;  is \i love u\" day. send dis to all ur \"valued frnds\" evn me. if 3 comes back u'll gt married d person u luv! if u ignore dis u will lose ur luv 4 evr"	0
+3832	i had a good time too. its nice to do something a bit different with my weekends for a change. see ya soon	0
+3833	urgent! we are trying to contact u. todays draw shows that you have won a ??2000 prize guaranteed. call 09066358361 from land line. claim y87. valid 12hrs only	1
+3834	that??s alrite girl, u know gail is neva wrong!!take care sweet and don??t worry.c u l8tr hun!love yaxxx	0
+3835	hiya, probably coming home * weekend after next	0
+3836	awesome, text me when you're restocked	0
+3837	missed call alert. these numbers called but left no message. 07008009200	1
+3838	just do what ever is easier for you	0
+3839	otherwise had part time job na-tuition..	0
+3840	are you angry with me. what happen dear	0
+3841	honestly i've just made a lovely cup of tea and promptly dropped my keys in it and then burnt my fingers getting them out!	0
+3842	thanx. yup we coming back on sun. finish dinner going back 2 hotel now. time flies, we're tog 4 exactly a mth today. hope we'll haf many more mths to come...	0
+3843	go until jurong point, crazy.. available only in bugis n great world la e buffet... cine there got amore wat...	0
+3844	are you driving or training?	0
+3845	hello lover! how goes that new job? are you there now? are you happy? do you think of me? i wake, my slave and send you a teasing kiss from across the sea	0
+3846	send his number and give reply tomorrow morning for why you said that to him like that ok	0
+3847	this phone has the weirdest auto correct.	0
+3848	so what u doing today?	0
+3849	so wat's da decision?	0
+3850	free any day but i finish at 6 on mon n thurs...	0
+3851	is there any movie theatre i can go to and watch unlimited movies and just pay once?	0
+3852	sir, waiting for your mail.	0
+3853	hi hope u r both ok, he said he would text and he hasn't, have u seen him, let me down gently please	0
+3854	i don't quite know what to do. i still can't get hold of anyone. i cud pick you up bout 7.30pm and we can see if they're in the pub?	0
+3855	i'm home.	0
+3856	why must we sit around and wait for summer days to celebrate. such a magical sight when the worlds dressed in white. oooooh let there be snow.	0
+3857	will be office around 4 pm. now i am going hospital.	0
+3858	one day a crab was running on the sea shore..the waves came n cleared the footprints of the crab.. crab asked: being my frnd y r u clearing my beautiful footprints? waves replied: a fox was following ur footprints to catch you! thats y i cleared it off:) frndsship never lets u dwn :-) gud nyt..	0
+3859	also fuck you and your family for going to rhode island or wherever the fuck and leaving me all alone the week i have a new bong &gt;:(	0
+3860	not planned yet :)going to join company on jan 5 only.don know what will happen after that.	0
+3861	she left it very vague. she just said she would inform the person in accounting about the delayed rent and that i should discuss with the housing agency about my renting another place. but checking online now and all places around usc are  &lt;#&gt;  and up	0
+3862	just buy a pizza. meat lovers or supreme. u get to pick.	0
+3863	had your mobile 10 mths? update to the latest camera/video phones for free. keep ur same number, get extra free mins/texts. text yes for a call	1
+3864	rofl betta invest in some  anti aging products	0
+3865	ha... both of us doing e same thing. but i got tv 2 watch. u can thk of where 2 go tonight or u already haf smth in mind...	0
+3866	thats a bit weird, even ?- where is the do supposed to be happening? but good idea, sure they will be in pub!	0
+3867	ok... take ur time n enjoy ur dinner...	0
+3868	oh ya ya. i remember da. .	0
+3869	no my mum went 2 dentist.	0
+3870	no * am working on the ringing u thing but have whole houseful of screaming brats so * am pulling my hair out! loving u	0
+3871	i cant pick the phone right now. pls send a message	0
+3872	yes..he is really great..bhaji told kallis best cricketer after sachin in world:).very tough to get out.	0
+3873	hi.what you think about match?	0
+3874	free entry into our ??250 weekly comp just send the word win to 80086 now. 18 t&c www.txttowin.co.uk	1
+3875	.please charge my mobile when you get up in morning.	0
+3876	1's finish meeting call me.	0
+3877	yay! finally lol. i missed our cinema trip last week :-(	0
+3878	missing you too.pray inshah allah	0
+3879	now i'm going for lunch.	0
+3880	k then 2marrow are you coming to class.	0
+3881	god created gap btwn ur fingers so dat sum1 vry special will fill those gaps by holding ur hands.. now plz dont ask y he created so much gap between legs !!!	0
+3882	you are being contacted by our dating service by someone you know! to find out who it is, call from a land line 09050000928. pobox45w2tg150p	1
+3883	i'm used to it. i just hope my agents don't drop me since i've only booked a few things this year. this whole me in boston, them in nyc was an experiment.	0
+3884	hi this is amy, we will be sending you a free phone number in a couple of days, which will give you an access to all the adult parties...	1
+3885	arr birthday today:) i wish him to get more oscar.	0
+3886	how much is torch in 9ja.	0
+3887	still chance there. if you search hard you will get it..let have a try :)	0
+3888	i tot u outside cos darren say u come shopping. of course we nice wat. we jus went sim lim look at mp3 player.	0
+3889	forgot to tell ?_ smth.. can ?_ like number the sections so that it's clearer..	0
+3890	s.i'm watching it in live..	0
+3891	can't take any major roles in community outreach. you rock mel	0
+3892	it's ok i noe u're busy but i'm really too bored so i msg u. i oso dunno wat colour she choose 4 me one.	0
+3893	is that on the telly? no its brdget jones!	0
+3894	is there a reason we've not spoken this year? anyways have a great week and all the best in your exam	0
+3895	carlos is down but i have to pick it up from him, so i'll swing by usf in a little bit	0
+3896	it does it on its own. most of the time it fixes my spelling. but sometimes it gets a completely diff word. go figure	0
+3897	you ve won! your 4* costa del sol holiday or ??5000 await collection. call 09050090044 now toclaim. sae, tc s, pobox334, stockport, sk38xh, cost??1.50/pm, max10mins	1
+3898	hey boys. want hot xxx pics sent direct 2 ur phone? txt porn to 69855, 24hrs free and then just 50p per day. to stop text stopbcm sf wc1n3xx	1
+3899	oh... kay... on sat right?	0
+3900	are you being good, baby? :)	0
+3901	did either of you have any idea's? do you know of anyplaces doing something?	0
+3902	you call him and tell now infront of them. call him now.	0
+3903	about  &lt;#&gt; bucks. the banks fees are fixed. better to call the bank and find out.	0
+3904	sms auction - a brand new nokia 7250 is up 4 auction today! auction is free 2 join & take part! txt nokia to 86021 now!	1
+3905	oops. 4 got that bit.	0
+3906	those ducking chinchillas	0
+3907	are your freezing ? are you home yet ? will you remember to kiss your mom in the morning? do you love me ? do you think of me ? are you missing me yet ?	0
+3908	aight text me when you're back at mu and i'll swing by, need somebody to get the door for me	0
+3909	japanese proverb: if one can do it, u too can do it, if none can do it,u must do it indian version: if one can do it, let him do it.. if none can do it,leave it!! and finally kerala version: if one can do it, stop him doing it.. if none can do it, make a strike against it ...	0
+3910	sorry i missed you babe. i was up late and slept in. i hope you enjoy your driving lesson, boytoy. i miss you too ... *teasing kiss*	0
+3911	i hope you that's the result of being consistently intelligent and kind. start asking him about practicum links and keep your ears open and all the best. ttyl	0
+3912	well welp is sort of a semiobscure internet thing	0
+3913	ok....take care.umma to you too...	0
+3914	are you this much buzy	0
+3915	tessy..pls do me a favor. pls convey my birthday wishes to nimya..pls dnt forget it. today is her birthday shijas	0
+3916	k.i will send in  &lt;#&gt;  min:)	0
+3917	already one guy loving you:-.	0
+3918	staff.science.nus.edu.sg/~phyhcmk/teaching/pc1323	0
+3919	new theory: argument wins d situation, but loses the person. so dont argue with ur friends just.. . . . kick them &amp; say, i'm always correct.!	0
+3920	they released vday shirts and when u put it on it makes your bottom half naked instead of those white underwear.	0
+3921	if you have belive me. come to my home.	0
+3922	hey mate! hows u honey?did u ave good holiday? gimmi de goss!x	0
+3923	turns out my friends are staying for the whole show and won't be back til ~ &lt;#&gt; , so feel free to go ahead and smoke that $ &lt;#&gt;  worth	0
+3924	k k pa had your lunch aha.	0
+3925	yup it's at paragon... i havent decided whether 2 cut yet... hee...	0
+3926	you are a winner u have been specially selected 2 receive ??1000 cash or a 4* holiday (flights inc) speak to a live operator 2 claim 0871277810810	1
+3927	we have new local dates in your area - lots of new people registered in your area. reply date to start now! 18 only www.flirtparty.us replys150	1
+3928	are you not around or just still asleep? :v	0
+3929	very hurting n meaningful lines ever: \i compromised everything for my love	0
+3930	i wonder if your phone battery went dead ? i had to tell you, i love you babe	0
+3931	no dear i do have free messages without any recharge. hi hi hi	0
+3932	not directly behind... abt 4 rows behind ?_...	0
+3933	lol ok. i'll snatch her purse too.	0
+3934	say this slowly.? god,i love you &amp; i need you,clean my heart with your blood.send this to ten special people &amp; u c miracle tomorrow, do it,pls,pls do it...	0
+3935	i wonder if you'll get this text?	0
+3936	thanks da thangam, i feel very very happy dear. i also miss you da.	0
+3937	hey i will be late ah... meet you at 945+	0
+3938	haha... they cant what... at the most tmr forfeit... haha so how?	0
+3939	oh k:)after that placement there ah?	0
+3940	k fyi x has a ride early tomorrow morning but he's crashing at our place tonight	0
+3941	sunshine quiz wkly q! win a top sony dvd player if u know which country the algarve is in? txt ansr to 82277. ??1.50 sp:tyrone	1
+3942	accordingly. i repeat, just text the word ok on your mobile phone and send	1
+3943	hello. no news on job, they are making me wait a fifth week! yeah im up for some woozles and weasels... in exeter still, but be home about 3.	0
+3944	ringtone club: gr8 new polys direct to your mobile every week !	1
+3945	what time is ur flight tmr?	0
+3946	boooo you always work. just quit.	0
+3947	mum not going robinson already.	0
+3948	congratulations - thanks to a good friend u have won the ??2,000 xmas prize. 2 claim is easy, just call 08712103738 now! only 10p per minute. bt-national-rate	1
+3949	great news! call freefone 08006344447 to claim your guaranteed ??1000 cash or ??2000 gift. speak to a live operator now!	1
+3950	i don,t think so. you don't need to be going out that late on a school night. especially when the one class you have is the one you missed last wednesday and probably failed a test in on friday	0
+3951	i jus reached home. i go bathe first. but my sis using net tell u when she finishes k...	0
+3952	thats cool! i am a gentleman and will treat you with dignity and respect.	0
+3953	just chill for another 6hrs. if you could sleep the pain is not a surgical emergency so see how it unfolds. okay	0
+3954	also tell him i said happy birthday	0
+3955	shit that is really shocking and scary, cant imagine for a second. def up for night out. do u think there is somewhere i could crash for night, save on taxi?	0
+3956	just sent again. do you scream and moan in bed, princess?	0
+3957	and stop wondering \wow is she ever going to stop tm'ing me ?!\" because i will tm you whenever i want because you are mine ... *laughs*"	0
+3958	i want to grasp your pretty booty :)	0
+3959	i remain unconvinced that this isn't an elaborate test of my willpower	0
+3960	like  &lt;#&gt; , same question	0
+3961	tone club: your subs has now expired 2 re-sub reply monoc 4 monos or polyc 4 polys 1 weekly @ 150p per week txt stop 2 stop this msg free stream 0871212025016	1
+3962	i am thinking of going down to reg for pract lessons.. flung my advance.. haha wat time u going?	0
+3963	18 days to euro2004 kickoff! u will be kept informed of all the latest news and results daily. unsubscribe send get euro stop to 83222.	1
+3964	urgent! your mobile no was awarded a ??2,000 bonus caller prize on 1/08/03! this is our 2nd attempt to contact you! call 0871-4719-523 box95qu bt national rate	1
+3965	k, can i pick up another 8th when you're done?	0
+3966	nothin comes to my mind. ?? help me buy hanger lor. ur laptop not heavy?	0
+3967	convey my regards to him	0
+3968	so how are you really. what are you up to. how's the masters. and so on.	0
+3969	storming msg: wen u lift d phne, u say \hello\" do u knw wt is d real meaning of hello?? . . . it's d name of a girl..! . . . yes.. and u knw who is dat girl?? \"margaret hello\" she is d girlfrnd f grahmbell who invnted telphone... . . . . moral:one can 4get d name of a person	0
+3970	when the first strike is a red one. the bird + antelope begin toplay in the fieldof selfindependence believe this + the flower of contention will grow.random!	0
+3971	purity of friendship between two is not about smiling after reading the forwarded message..its about smiling just by seeing the name. gud evng musthu	0
+3972	ok no prob... i'll come after lunch then...	0
+3973	how are you doing? hope you've settled in for the new school year. just wishin you a gr8 day	0
+3974	y she dun believe leh? i tot i told her it's true already. i thk she muz c us tog then she believe.	0
+3975	is there any training tomorrow?	0
+3976	jos ask if u wana meet up?	0
+3977	1) go to write msg 2) put on dictionary mode 3)cover the screen with hand, 4)press  &lt;#&gt; . 5)gently remove ur hand.. its interesting..:)	0
+3978	good afternoon sexy buns! how goes the job search ? i wake and you are my first thought as always, my love. i wish your fine and happy and know i adore you!	0
+3979	we made it! eta at taunton is 12:30 as planned, hope that???s still okday?! good to see you! :-xx	0
+3980	hi its lucy hubby at meetins all day fri & i will b alone at hotel u fancy cumin over? pls leave msg 2day 09099726395 lucy x calls??1/minmobsmorelkpobox177hp51fl	1
+3981	damn, poor zac doesn't stand a chance	0
+3982	will do. have a good day	0
+3983	mum say we wan to go then go... then she can shun bian watch da glass exhibition...	0
+3984	how come i din c ?_... yup i cut my hair...	0
+3985	how come she can get it? should b quite diff to guess rite...	0
+3986	surely result will offer:)	0
+3987	cool, we shall go and see, have to go to tip anyway. are you at home, got something to drop in later? so lets go to town tonight! maybe mum can take us in.	0
+3988	which channel:-):-):):-).	0
+3989	for taking part in our mobile survey yesterday! you can now have 500 texts 2 use however you wish. 2 get txts just send txt to 80160 t&c www.txt43.com 1.50p	1
+3990	tell me they're female :v how're you throwing in? we're deciding what all to get now	0
+3991	slaaaaave ! where are you ? must i summon you to me all the time now ? don't you wish to come to me on your own anymore?	0
+3992	there is a first time for everything :)	0
+3993	they will pick up and drop in car.so no problem..	0
+3994	alright, we're all set here, text the man	0
+3995	awesome, how do i deal with the gate? charles told me last night but, uh, yeah	0
+3996	take us out shopping and mark will distract isaiah.=d	0
+3997	have you seen who's back at holby?!	0
+3998	so wats ur opinion abt him and how abt is character?	0
+3999	bloody hell, cant believe you forgot my surname mr . ill give u a clue, its spanish and begins with m...	0
+4000	still chance there. if you search hard you will get it..let have a try :)	0
+4001	this is the 2nd time we have tried 2 contact u. u have won the 750 pound prize. 2 claim is easy, call 08712101358 now! only 10p per min. bt-national-rate	1
+4002	my trip was ok but quite tiring lor. uni starts today but it's ok 4 me cos i'm not taking any modules but jus concentrating on my final yr project.	0
+4003	yesterday its with me only . now am going home.	0
+4004	ups which is 3days also, and the shipping company that takes 2wks. the other way is usps which takes a week but when it gets to lag you may have to bribe nipost to get your stuff.	0
+4005	hi! you just spoke to maneesha v. we'd like to know if you were satisfied with the experience. reply toll free with yes or no.	0
+4006	sorry, my battery died, i can come by but i'm only getting a gram for now, where's your place?	0
+4007	company is very good.environment is terrific and food is really nice:)	0
+4008	doing my masters. when will you buy a bb cos i have for sale and how's bf	0
+4009	do i? i thought i put it back in the box	0
+4010	recpt 1/3. you have ordered a ringtone. your order is being processed...	1
+4011	kent vale lor... ?? wait 4 me there ar?	0
+4012	rt-king pro video club>> need help? info@ringtoneking.co.uk or call 08701237397 you must be 16+ club credits redeemable at www.ringtoneking.co.uk! enjoy!	1
+4013	no screaming means shouting..	0
+4014	so now my dad is gonna call after he gets out of work and ask all these crazy questions.	0
+4015	i am taking you for italian food. how about a pretty dress with no panties? :)	0
+4016	single line with a big meaning::::: \miss anything 4 ur \"best life\" but	0
+4017	whats the staff name who is taking class for us?	0
+4018	that's fine, have him give me a call if he knows what he wants or has any questions	0
+4019	lol great now im getting hungry.	0
+4020	this is the 2nd time we have tried 2 contact u. u have won the ??750 pound prize. 2 claim is easy, call 087187272008 now1! only 10p per minute. bt-national-rate.	1
+4021	i come n pick ?_ up... come out immediately aft ur lesson...	0
+4022	you won't believe it but it's true. it's incredible txts! reply g now to learn truly amazing things that will blow your mind. from o2fwd only 18p/txt	1
+4023	haha better late than ever, any way i could swing by?	0
+4024	no need lar. jus testing e phone card. dunno network not gd i thk. me waiting 4 my sis 2 finish bathing so i can bathe. dun disturb u liao u cleaning ur room.	0
+4025	how are you with money...as in to you...money aint a thing....how are you sha!	0
+4026	huh so late... fr dinner?	0
+4027	jus ans me lar. u'll noe later.	0
+4028	can i get your opinion on something first?	0
+4029	cheers lou! yeah was a goodnite shame u neva came! c ya gailxx	0
+4030	have a good evening! ttyl	0
+4031	infact happy new year. how are you where are you when are we seeing	0
+4032	you will recieve your tone within the next 24hrs. for terms and conditions please see channel u teletext pg 750	1
+4033	raviyog peripherals bhayandar east	0
+4034	looks like you found something to do other than smoke, great job!	0
+4035	i got like $ &lt;#&gt; , i can get some more later though. get whatever you feel like	0
+4036	you 07801543489 are guaranteed the latests nokia phone, a 40gb ipod mp3 player or a ??500 prize! txt word:collect to no:83355! tc-llc ny-usa 150p/mt msgrcvd18+	1
+4037	just gettin a bit arty with my collages at the mo, well tryin 2 ne way! got a roast in a min lovely i shall enjoy that!	0
+4038	i went to project centre	0
+4039	i don't have anybody's number, i still haven't thought up a tactful way to ask alex	0
+4040	found it, enc  &lt;#&gt; , where you at?	0
+4041	i promise to take good care of you, princess. i have to run now. please send pics when you get a chance. ttyl!	0
+4042	when should i come over?	0
+4043	k i'm ready,  &lt;#&gt; ?	0
+4044	yes. please leave at  &lt;#&gt; . so that at  &lt;#&gt;  we can leave	0
+4045	buzz! hey, my love ! i think of you and hope your day goes well. did you sleep in ? i miss you babe. i long for the moment we are together again*loving smile*	0
+4046	u have a secret admirer who is looking 2 make contact with u-find out who they r*reveal who thinks ur so special-call on 09058094594	1
+4047	hey mr  and i are going to the sea view and having a couple of gays i mean games! give me a bell when ya finish	0
+4048	sweet, we may or may not go to 4u to meet carlos so gauge patty's interest in that	0
+4049	yup no more already... thanx 4 printing n handing it up.	0
+4050	set a place for me in your heart and not in your mind, as the mind easily forgets but the heart will always remember. wish you happy valentines day!	0
+4051	i wonder how you got online, my love ? had you gone to the net cafe ? did you get your phone recharged ? were you on a friends net ? i think of you, boytoy	0
+4052	then any special there?	0
+4053	hi babe im at home now wanna do something? xx	0
+4054	fyi i'm at usf now, swing by the room whenever	0
+4055	loan for any purpose ??500 - ??75,000. homeowners + tenants welcome. have you been previously refused? we can still help. call free 0800 1956669 or text back 'help'	1
+4056	sure but since my parents will be working on tuesday i don't really need a cover story	0
+4057	at 4. let's go to bill millers	0
+4058	alright. i'm out--have a good night!	0
+4059	ha... then we must walk to everywhere... cannot take tram. my cousin said can walk to vic market from our hotel	0
+4060	aight, i'll ask a few of my roommates	0
+4061	now project pa. after that only i can come.	0
+4062	just sleeping..and surfing	0
+4063	2/2 146tf150p	1
+4064	we'll join the  &lt;#&gt;  bus	0
+4065	ok	0
+4066	sorry, no, have got few things to do. may be in pub later.	0
+4067	how r ?_ going to send it to me?	0
+4068	?? no home work to do meh...	0
+4069	you still coming tonight?	0
+4070	dude just saw a parked car with its sunroof popped up. sux	0
+4071	depends on individual lor e hair dresser say pretty but my parents say look gong. u kaypoh.. i also dunno wat she collecting.	0
+4072	yeah we do totes. when u wanna?	0
+4073	i can??t wait for cornwall. hope tonight isn??t too bad as well but it??s rock night shite. anyway i??m going for a kip now have a good night. speak to you soon.	0
+4074	how would my ip address test that considering my computer isn't a minecraft server	0
+4075	haha okay... today weekend leh...	0
+4076	freemsg why haven't you replied to my text? i'm randy, sexy, female and live local. luv to hear from u. netcollex ltd 08700621170150p per msg reply stop to end	1
+4077	07732584351 - rodger burns - msg = we tried to call you re your reply to our sms for a free nokia mobile + free camcorder. please call now 08000930705 for delivery tomorrow	1
+4078	thought praps you meant another one. goodo! i'll look tomorrow	0
+4079	well there's not a lot of things happening in lindsay on new years *sighs* some bars in ptbo and the blue heron has something going	0
+4080	this is the 2nd time we have tried to contact u. u have won the ??400 prize. 2 claim is easy, just call 087104711148 now! only 10p per minute. bt-national-rate	1
+4081	ok then u tell me wat time u coming later lor.	0
+4082	purity of friendship between two is not about smiling after reading the forwarded message..its about smiling just by seeing the name. gud evng musthu	0
+4083	unni thank you dear for the recharge..rakhesh	0
+4084	dear i have reache room	0
+4085	please call 08712402578 immediately as there is an urgent message waiting for you	1
+4086	lol grr my mom is taking forever with my prescription. pharmacy is like 2 minutes away. ugh.	0
+4087	i like to think there's always the possibility of being in a pub later.	0
+4088	knock knock txt whose there to 80082 to enter r weekly draw 4 a ??250 gift voucher 4 a store of yr choice. t&cs www.tkls.com age16 to stoptxtstop??1.50/week	1
+4089	with my sis lor... we juz watched italian job.	0
+4090	we're finally ready fyi	0
+4091	what time u wrkin?	0
+4092	i know girls always safe and selfish know i got it pa. thank you. good night.	0
+4093	is xy in ur car when u picking me up?	0
+4094	ok thanx...	0
+4095	win urgent! your mobile number has been awarded with a ??2000 prize guaranteed call 09061790121 from land line. claim 3030 valid 12hrs only 150ppm	1
+4096	i.ll hand her my phone to chat wit u	0
+4097	yes :)it completely in out of form:)clark also utter waste.	0
+4098	married local women looking for discreet action now! 5 real matches instantly to your phone. text match to 69969 msg cost 150p 2 stop txt stop bcmsfwc1n3xx	1
+4099	sorry to be a pain. is it ok if we meet another night? i spent late afternoon in casualty and that means i haven't done any of y stuff42moro and that includes all my time sheets and that. sorry.	0
+4100	todays vodafone numbers ending with 4882 are selected to a receive a ??350 award. if your number matches call 09064019014 to receive your ??350 award.	1
+4101	it has issues right now. ill fix for her by tomorrow.	0
+4102	how are you? i miss you!	0
+4103	k..k..i'm also fine:)when will you complete the course?	0
+4104	there bold 2  &lt;#&gt; . is that yours	0
+4105	cool, text me when you're parked	0
+4106	new mobiles from 2004, must go! txt: nokia to no: 89545 & collect yours today! from only ??1. www.4-tc.biz 2optout 087187262701.50gbp/mtmsg18 txtauction.	1
+4107	i called and said all to him:)then he have to choose this future.	0
+4108	hey come online! use msn... we are all there	0
+4109	ard 6 like dat lor.	0
+4110	he remains a bro amongst bros	0
+4111	buy one egg for me da..please:)	0
+4112	takin a shower now but yeah i'll leave when i'm done	0
+4113	urgent -call 09066649731from landline. your complimentary 4* ibiza holiday or ??10,000 cash await collection sae t&cs po box 434 sk3 8wp 150ppm 18+	1
+4114	leave it de:-). start prepare for next:-)..	0
+4115	do whatever you want. you know what the rules are. we had a talk earlier this week about what had to start happening, you showing responsibility. yet, every week it's can i bend the rule this way? what about that way? do whatever. i'm tired of having thia same argument with you every week. and a  &lt;#&gt;  movie doesnt inlude the previews. you're still getting in after 1.	0
+4116	wait  &lt;#&gt;  min..	0
+4117	dont worry. i guess he's busy.	0
+4118	you are sweet as well, princess. please tell me your likes and dislikes in bed...	0
+4119	what is this 'hex' place you talk of? explain!	0
+4120	to the wonderful okors, have a great month. we cherish you guys and wish you well each day. mojibiola	0
+4121	why nothing. ok anyway give me treat	0
+4122	i love u 2 my little pocy bell i am sorry but i love u	0
+4123	a cute thought for friendship: \its not necessary to share every secret with ur close frnd	0
+4124	tell dear what happen to you. why you talking to me like an alian	0
+4125	no calls..messages..missed calls	0
+4126	mostly sports type..lyk footbl,crckt..	0
+4127	wake me up at  &lt;#&gt;  am morning:)	0
+4128	yes just finished watching days of our lives. i love it.	0
+4129	hello darling how are you today? i would love to have a chat, why dont you tell me what you look like and what you are in to sexy?	1
+4130	for real when u getting on yo? i only need 2 more tickets and one more jacket and i'm done. i already used all my multis.	0
+4131	ok that's great thanx a lot.	0
+4132	hey i've booked the 2 lessons on sun liao...	0
+4133	hope you enjoyed your new content. text stop to 61610 to unsubscribe. help:08712400602450p provided by tones2you.co.uk	1
+4134	ok lor... or u wan me go look 4 u?	0
+4135	you are being contacted by our dating service by someone you know! to find out who it is, call from your mobile or landline 09064017305 pobox75ldns7	1
+4136	probably a couple hours tops	0
+4137	hello handsome ! are you finding that job ? not being lazy ? working towards getting back that net for mummy ? where's my boytoy now ? does he miss me ?	0
+4138	all these nice new shirts and the only thing i can wear them to is nudist themed ;_; you in mu?	0
+4139	each moment in a day,has its own value-morning brings hope,afternoon brings faith,evening brings luv,night brings rest,wish u find them all today.good morning	0
+4140	id onluy matters when getting on from offcampus	0
+4141	talk sexy!! make new friends or fall in love in the worlds most discreet text dating service. just text vip to 83110 and see who you could meet.	1
+4142	winner!! as a valued network customer you have been selected to receivea ??900 prize reward! to claim call 09061701461. claim code kl341. valid 12 hours only.	1
+4143	miss ya, need ya, want ya, love ya.	0
+4144	merry christmas to you too babe, i love ya *kisses*	0
+4145	i am late. i will be there at	0
+4146	i dont have i shall buy one dear	0
+4147	are you wet right now?	0
+4148	sms. ac sun0819 posts hello:\you seem cool	1
+4149	sweet heart how are you?	0
+4150	its good, we'll find a way	0
+4151	he said that he had a right giggle when he saw u again! you would possibly be the first person2die from nvq, but think how much you could for!	0
+4152	yar lor wait 4 my mum 2 finish sch then have lunch lor... i whole morning stay at home clean my room now my room quite clean... hee...	0
+4153	someone u know has asked our dating service 2 contact you! cant guess who? call 09058091854 now all will be revealed. po box385 m6 6wu	1
+4154	it's wylie, you in tampa or sarasota?	0
+4155	i dont know ask to my brother. nothing problem some thing that. just i told .	0
+4156	darren was saying dat if u meeting da ge den we dun meet 4 dinner. cos later u leave xy will feel awkward. den u meet him 4 lunch lor.	0
+4157	ok can...	0
+4158	good evening sir, hope you are having a nice day. i wanted to bring it to your notice that i have been late in paying rent for the past few months and have had to pay a $ &lt;#&gt;  charge. i felt it would be inconsiderate of me to nag about something you give at great cost to yourself and that's why i didnt speak up. i however am in a recession and wont be able to pay the charge this month hence my askin well ahead of month's end. can you please help. thank you for everything.	0
+4159	received, understood n acted upon!	0
+4160	okay, good, no problem, and thanx!	0
+4161	ok	0
+4162	eastenders tv quiz. what flower does dot compare herself to? d= violet e= tulip f= lily txt d e or f to 84025 now 4 chance 2 win ??100 cash wkent/150p16+	1
+4163	come around  &lt;decimal&gt; pm vikky..i'm otside nw, il come by tht time	0
+4164	u have a secret admirer. reveal who thinks u r so special. call 09065174042. to opt out reply reveal stop. 1.50 per msg recd. cust care 07821230901	1
+4165	pls send me your address sir.	0
+4166	i absolutely love south park! i only recently started watching the office.	0
+4167	hmmm.... mayb can try e shoppin area one, but forgot e name of hotel...	0
+4168	you are gorgeous! keep those pix cumming :) thank you!	0
+4169	she went to attend another two rounds today..but still did't reach home..	0
+4170	ok. but i finish at 6.	0
+4171	pls give her the food preferably pap very slowly with loads of sugar. you can take up to an hour to give it. and then some water. very very slowly.	0
+4172	no just send to you. bec you in temple na.	0
+4173	gam gone after outstanding innings.	0
+4174	you are being contacted by our dating service by someone you know! to find out who it is, call from a land line 09050000878. pobox45w2tg150p	1
+4175	bloomberg -message center +447797706009 why wait? apply for your future http://careers. bloomberg.com	1
+4176	you have an important customer service announcement. call freephone 0800 542 0825 now!	1
+4177	k...k...when will you give treat?	0
+4178	dunno lei... i thk mum lazy to go out... i neva ask her yet...	0
+4179	freemsg hey there darling it's been 3 week's now and no word back! i'd like some fun you up for it still? tb ok! xxx std chgs to send, ??1.50 to rcv	1
+4180	ok...	0
+4181	haf u found him? i feel so stupid da v cam was working.	0
+4182	hello! how's you and how did saturday go? i was just texting to see if you'd decided to do anything tomo. not that i'm trying to invite myself or anything!	0
+4183	\symptoms\" when u are in love: \"1.u like listening songs 2.u get stopped where u see the name of your beloved 3.u won't get angry when your"	0
+4184	i did. one slice and one breadstick. lol	0
+4185	thanks for your ringtone order, ref number k718. your mobile will be charged ??4.50. should your tone not arrive please call customer services on 09065069120	1
+4186	hey you can pay. with salary de. only  &lt;#&gt; .	0
+4187	if u laugh really loud.. if u talk spontaneously.. if u dont care what others feel.. u are probably with your dear &amp; best friends.. goodevening dear..:)	0
+4188	you know there is. i shall speak to you in  &lt;#&gt;  minutes then	0
+4189	also hi wesley how've you been	0
+4190	eastenders tv quiz. what flower does dot compare herself to? d= violet e= tulip f= lily txt d e or f to 84025 now 4 chance 2 win ??100 cash wkent/150p16+	1
+4191	why you keeping me away like this	0
+4192	that's ok. i popped in to ask bout something and she said you'd been in. are you around tonght wen this girl comes?	0
+4193	and several to you sir.	0
+4194	depends on quality. if you want the type i sent boye, faded glory, then about 6. if you want ralphs maybe 2	0
+4195	sday only joined.so training we started today:)	0
+4196	u will switch your fone on dammit!!	0
+4197	lil fever:) now fine:)	0
+4198	be happy there. i will come after noon	0
+4199	as per your request 'melle melle (oru minnaminunginte nurungu vettam)' has been set as your callertune for all callers. press *9 to copy your friends callertune	0
+4200	get the official england poly ringtone or colour flag on yer mobile for tonights game! text tone or flag to 84199. optout txt eng stop box39822 w111wx ??1.50	1
+4201	i'm e person who's doing e sms survey...	0
+4202	come to mahal bus stop.. &lt;decimal&gt;	0
+4203	marvel mobile play the official ultimate spider-man game (??4.50) on ur mobile right now. text spider to 83338 for the game & we ll send u a free 8ball wallpaper	1
+4204	yup song bro. no creative. neva test quality. he said check review online.	0
+4205	nite...	0
+4206	hahaha..use your brain dear	0
+4207	free2day sexy st george's day pic of jordan!txt pic to 89080 dont miss out, then every wk a saucy celeb!4 more pics c pocketbabe.co.uk 0870241182716 ??3/wk	1
+4208	haha, my legs and neck are killing me and my amigos are hoping to end the night with a burn, think i could swing by in like an hour?	0
+4209	lol you forgot it eh ? yes, i'll bring it in babe	0
+4210	yup bathe liao...	0
+4211	loan for any purpose ??500 - ??75,000. homeowners + tenants welcome. have you been previously refused? we can still help. call free 0800 1956669 or text back 'help'	1
+4212	ok lor but not too early. me still having project meeting now.	0
+4213	white fudge oreos are in stores	0
+4214	die... i accidentally deleted e msg i suppose 2 put in e sim archive. haiz... i so sad...	0
+4215	omg you can make a wedding chapel in frontierville? why do they get all the good stuff?	0
+4216	r u &sam p in eachother. if we meet we can go 2 my house	0
+4217	geeee ... i miss you already, you know ? your all i can think about. fuck, i can't wait till next year when we will be together ... *loving kiss*	0
+4218	hope you??re not having too much fun without me!! see u tomorrow love jess x	0
+4219	hi baby im sat on the bloody bus at the mo and i wont be home until about 7:30 wanna do somethin later? call me later ortxt back jess xx	0
+4220	what i mean is do they come chase you out when its over or is it stated you can watch as many movies as you want.	0
+4221	oops i did have it,  &lt;#&gt; ?	0
+4222	alrite sam its nic just checkin that this is ur number-so is it?t.b*	0
+4223	i'm really sorry i won't b able 2 do this friday.hope u can find an alternative.hope yr term's going ok:-)	0
+4224	hi jon, pete here, ive bin 2 spain recently & hav sum dinero left, bill said u or ur ??rents mayb interested in it, i hav 12,000pes, so around ??48, tb, james.	0
+4225	its ok..come to my home it vl nice to meet and v can chat..	0
+4226	?? eatin later but i'm eatin wif my frens now lei... ?? going home first?	0
+4227	ok then i'll let him noe later n ask him call u tmr...	0
+4228	your chance to be on a reality fantasy show call now = 08707509020 just 20p per min ntt ltd, po box 1327 croydon cr9 5wb 0870 is a national = rate call	1
+4229	i want some cock! my hubby's away, i need a real man 2 satisfy me. txt wife to 89938 for no strings action. (txt stop 2 end, txt rec ??1.50ea. otbox 731 la1 7ws. )	1
+4230	babe ! how goes that day ? what are you doing ? where are you ? i sip my cappuccino and think of you, my love ... i send a kiss to you from across the sea	0
+4231	urgent!: your mobile no. was awarded a ??2,000 bonus caller prize on 02/09/03! this is our 2nd attempt to contact you! call 0871-872-9755 box95qu	1
+4232	not getting anywhere with this damn job hunting over here!	0
+4233	wen ur lovable bcums angry wid u, dnt take it seriously.. coz being angry is d most childish n true way of showing deep affection, care n luv!.. kettoda manda... have nice day da.	0
+4234	what do you do, my dog ? must i always wait till the end of your day to have word from you ? did you run out of time on your cell already?	0
+4235	how stupid to say that i challenge god.you dont think at all on what i write instead you respond immed.	0
+4236	make sure alex knows his birthday is over in fifteen minutes as far as you're concerned	0
+4237	congrats! 1 year special cinema pass for 2 is yours. call 09061209465 now! c suprman v, matrix3, starwars3, etc all 4 free! bx420-ip4-5we. 150pm. dont miss out!	1
+4238	private! your 2003 account statement for <fone no> shows 800 un-redeemed s. i. m. points. call 08715203656 identifier code: 42049 expires 26/10/04	1
+4239	jay is snickering and tells me that x is totally fucking up the chords as we speak	0
+4240	left dessert. u wan me 2 go suntec look 4 u?	0
+4241	she.s good. she was wondering if you wont say hi but she.s smiling now. so how are you coping with the long distance	0
+4242	where's mummy's boy ? is he being good or bad ? is he being positive or negative ? why is mummy being made to wait? hmmmm?	0
+4243	no:-)i got rumour that you going to buy apartment in chennai:-)	0
+4244	sexy singles are waiting for you! text your age followed by your gender as wither m or f e.g.23f. for gay men text your age followed by a g. e.g.23g.	1
+4245	freemsg hey u, i just got 1 of these video/pic fones, reply wild to this txt & ill send u my pics, hurry up im so bored at work xxx (18 150p/rcvd stop2stop)	1
+4246	went to pay rent. so i had to go to the bank to authorise the payment.	0
+4247	k.k:)advance happy pongal.	0
+4248	wat time ?_ wan today?	0
+4249	a swt thought: \nver get tired of doing little things 4 lovable persons..\" coz..somtimes those little things occupy d biggest part in their hearts.. gud ni8"	0
+4250	we still on for tonight?	0
+4251	no i'm in the same boat. still here at my moms. check me out on yo. i'm half naked.	0
+4252	what i mean was i left too early to check, cos i'm working a 9-6.	0
+4253	have got * few things to do. may be in * pub later.	0
+4254	?? only send me the contents page...	0
+4255	omg joanna is freaking me out. she's looked thru all my friends to find photos of me. and then she's asking about stuff on my myspace which i haven't even logged on in like a year. :/	0
+4256	sweetheart, hope you are not having that kind of day! have one with loads of reasons to smile. biola	0
+4257	sounds great! im going to sleep now. have a good night!	0
+4258	we left already we at orchard now.	0
+4259	of course ! don't tease me ... you know i simply must see ! *grins* ... do keep me posted my prey ... *loving smile* *devouring kiss*	0
+4260	remember all those whom i hurt during days of satanic imposter in me.need to pay a price,so be it.may destiny keep me going and as u said pray that i get the mind to get over the same.	0
+4261	\none!nowhere ikno doesdiscount!shitinnit\""	0
+4262	i asked sen to come chennai and search for job.	0
+4263	i'm sick !! i'm needy !! i want you !! *pouts* *stomps feet* where are you ?! *pouts* *stomps feet* i want my slave !! i want him now !!	0
+4264	i can't speak, bcaz mobile have problem. i can listen you but you cann't listen my voice. so i calls you later.	0
+4265	i'm in class. did you get my text.	0
+4266	k k :-):-) then watch some films.	0
+4267	anything lor if they all go then i go lor...	0
+4268	thank you meet you monday	0
+4269	ok ill tell the company	0
+4270	have your lunch and come quickly and open the door:)	0
+4271	hey...great deal...farm tour 9am to 5pm $95/pax, $50 deposit by 16 may	0
+4272	yeah he got in at 2 and was v apologetic. n had fallen out and she was actin like spoilt child and he got caught up in that. till 2! but we won't go there! not doing too badly cheers. you?	0
+4273	k..then come wenever u lik to come and also tel vikky to come by getting free time..:-)	0
+4274	ard 4 lor...	0
+4275	no problem. talk to you later	0
+4276	do have a nice day today. i love you so dearly.	0
+4277	how are you babes. hope your doing ok. i had a shit nights sleep. i fell asleep at 5.i??m knackered and i??m dreading work tonight. what are thou upto tonight. x	0
+4278	yes we are chatting too.	0
+4279	i'm in office now da:)where are you?	0
+4280	bears pic nick, and tom, pete and ... dick. in fact, all types try gay chat with photo upload call 08718730666 (10p/min). 2 stop texts call 08712460324	1
+4281	no b4 thursday	0
+4282	i will lick up every drop :) are you ready to use your mouth as well?	0
+4283	free ringtone text first to 87131 for a poly or text get to 87131 for a true tone! help? 0845 2814032 16 after 1st free, tones are 3x??150pw to e??nd txt stop	1
+4284	great princess! i love giving and receiving oral. doggy style is my fave position. how about you? i enjoy making love  &lt;#&gt;  times per night :)	0
+4285	i do know what u mean,  is the king of not havin credit! i'm goin2bed now. night night sweet! only1more sleep!	0
+4286	\hi babe uawake?feellikw shit.justfound out via aletter thatmum gotmarried 4thnov.behind ourbacks ?? fuckinnice!selfish i??l call u\""	0
+4287	gosh that , what a pain. spose i better come then.	0
+4288	congratulations! thanks to a good friend u have won the ??2,000 xmas prize. 2 claim is easy, just call 08718726971 now! only 10p per minute. bt-national-rate.	1
+4289	love you aathi..love u lot..	0
+4290	sir, i am waiting for your mail.	0
+4291	in which place do you want da.	0
+4292	reason is if the team budget is available at last they buy the unsold players for at base rate..	0
+4293	sorry, i'll call later	0
+4294	i want to go to perumbavoor	0
+4295	are u awake? is there snow there?	0
+4296	hiya comin 2 bristol 1 st week in april. les got off + rudi on new yrs eve but i was snoring.they were drunk! u bak at college yet? my work sends ink 2 bath.	0
+4297	i dun believe u. i thk u told him.	0
+4298	daddy will take good care of you :)	0
+4299	yeah sure thing mate haunt got all my stuff sorted but im going sound anyway promoting hex for .by the way who is this? dont know number. joke	0
+4300	your account has been refilled successfully by inr  &lt;decimal&gt; . your keralacircle prepaid account balance is rs  &lt;decimal&gt; . your transaction id is kr &lt;#&gt; .	0
+4301	guai... ?? shd haf seen him when he's naughty... ?? so free today? can go jogging...	0
+4302	he said i look pretty wif long hair wat. but i thk he's cutting quite short 4 me leh.	0
+4303	haha yeah, 2 oz is kind of a shitload	0
+4304	s but mostly not like that.	0
+4305	gr8 new service - live sex video chat on your mob - see the sexiest dirtiest girls live on ur phone - 4 details text horny to 89070 to cancel send stop to 89070	1
+4306	lemme know when i can swing by and pick up, i'm free basically any time after 1 all this semester	0
+4307	okie...	0
+4308	1 i don't have her number and 2 its gonna be a massive pain in the ass and i'd rather not get involved if that's possible	0
+4309	hiya. how was last night? i've been naughty and bought myself clothes and very little ... ready for more shopping tho! what kind of time do you wanna meet?	0
+4310	my parents, my kidz, my friends n my colleagues. all screaming.. surprise !! and i was waiting on the sofa.. ... ..... ' naked...!	0
+4311	why are u up so early?	0
+4312	tomorrow i am not going to theatre. . . so i can come wherever u call me. . . tell me where and when to come tomorrow	0
+4313	dunno lei ?_ all decide lor. how abt leona? oops i tot ben is going n i msg him.	0
+4314	camera - you are awarded a sipix digital camera! call 09061221066 fromm landline. delivery within 28 days.	1
+4315	how to make a girl happy? it's not at all difficult to make girls happy. u only need to be... 1. a friend 2. companion 3. lover 4. chef . . .  &lt;#&gt; . good listener  &lt;#&gt; . organizer  &lt;#&gt; . good boyfriend  &lt;#&gt; . very clean  &lt;#&gt; . sympathetic  &lt;#&gt; . athletic  &lt;#&gt; . warm . . .  &lt;#&gt; . courageous  &lt;#&gt; . determined  &lt;#&gt; . true  &lt;#&gt; . dependable  &lt;#&gt; . intelligent . . .  &lt;#&gt; . psychologist  &lt;#&gt; . pest exterminator  &lt;#&gt; . psychiatrist  &lt;#&gt; . healer . .  &lt;#&gt; . stylist  &lt;#&gt; . driver . . aaniye pudunga venaam..	0
+4316	and very importantly, all we discuss is between u and i only.	0
+4317	tell me pa. how is pain de.	0
+4318	aiyo please ?_ got time meh.	0
+4319	i have a rather prominent bite mark on my right cheek	0
+4320	once free call me sir. i am waiting for you.	0
+4321	aight, tomorrow around  &lt;#&gt;  it is	0
+4322	thanx u darlin!im cool thanx. a few bday drinks 2 nite. 2morrow off! take care c u soon.xxx	0
+4323	urgent! we are trying to contact u. todays draw shows that you have won a ??800 prize guaranteed. call 09050001808 from land line. claim m95. valid12hrs only	1
+4324	message from . i am at truro hospital on ext. you can phone me here. as i have a phone by my side	0
+4325	can a not?	0
+4326	yeah there's barely enough room for the two of us, x has too many fucking shoes. sorry man, see you later	0
+4327	she.s fine. i have had difficulties with her phone. it works with mine. can you pls send her another friend request.	0
+4328	my birthday is on feb  &lt;#&gt;  da. .	0
+4329	dhoni have luck to win some big title.so we will win:)	0
+4330	hmmm... guess we can go 4 kb n power yoga... haha, dunno we can tahan power yoga anot... thk got lo oso, forgot liao...	0
+4331	nothing, i got msg frm tht unknown no..	0
+4332	yep then is fine 7.30 or 8.30 for ice age.	0
+4333	so ?_ pay first lar... then when is da stock comin...	0
+4334	yo sorry was in the shower sup	0
+4335	u can win ??100 of music gift vouchers every week starting now txt the word draw to 87066 tscs www.ldew.com skillgame,1winaweek, age16.150ppermesssubscription	1
+4336	just got to  &lt;#&gt;	0
+4337	i love your ass! do you enjoy doggy style? :)	0
+4338	really dun bluff me leh... u sleep early too. nite...	0
+4339	jus came back fr lunch wif my sis only. u leh?	0
+4340	we not leaving yet. ok lor then we go elsewhere n eat. u thk...	0
+4341	you have 1 new voicemail. please call 08719181503	1
+4342	we can make a baby in yo tho	0
+4343	you are now unsubscribed all services. get tons of sexy babes or hunks straight to your phone! go to http://gotbabes.co.uk. no subscriptions.	1
+4344	dear we got  &lt;#&gt;  dollars hi hi	0
+4345	sunshine quiz wkly q! win a top sony dvd player if u know which country liverpool played in mid week? txt ansr to 82277. ??1.50 sp:tyrone	1
+4346	you have been specially selected to receive a \3000 award! call 08712402050 before the lines close. cost 10ppm. 16+. t&cs apply. ag promo"	1
+4347	did you hear about the new \divorce barbie\"? it comes with all of ken's stuff!"	1
+4348	sun cant come to earth but send luv as rays. cloud cant come to river but send luv as rain. i cant come to meet u, but can send my care as msg to u. gud evng	0
+4349	oh... i was thkin of goin yogasana at 10 den no nd to go at 3 den can rush to parco 4 nb... okie lor, u call me when ready...	0
+4350	i don't run away frm u... i walk slowly &amp; it kills me that u don't care enough to stop me...	0
+4351	\the world suffers a lot... not because of the violence of bad people. but because of the silence of good people!\"	0
+4352	tddnewsletter@emc1.co.uk (more games from thedailydraw) dear helen, dozens of free games - with great prizeswith..	1
+4353	fuck babe ... i miss you already, you know ? can't you let me send you some money towards your net ? i need you ... i want you ... i crave you ...	0
+4354	feb  &lt;#&gt;  is \i love u\" day. send dis to all ur \"valued frnds\" evn me. if 3 comes back u'll gt married d person u luv! if u ignore dis u will lose ur luv 4 evr"	0
+4355	i'm doing da intro covers energy trends n pros n cons... brief description of nuclear fusion n oso brief history of iter n jet got abt 7 n half pages..	0
+4356	sex up ur mobile with a free sexy pic of jordan! just text babe to 88600. then every wk get a sexy celeb! pocketbabe.co.uk 4 more pics. 16 ??3/wk 087016248	1
+4357	oops - am at my mum's in somerset... bit far! back tomo, see you soon x	0
+4358	busy here. trying to finish for new year. i am looking forward to finally meeting you...	0
+4359	no plm i will come da. on the way.	0
+4360	romantic paris. 2 nights, 2 flights from ??79 book now 4 next year. call 08704439680ts&cs apply.	1
+4361	i am going to sleep. i am tired of travel.	0
+4362	wah lucky man... then can save money... hee...	0
+4363	midnight at the earliest	0
+4364	got what it takes 2 take part in the wrc rally in oz? u can with lucozade energy! text rally le to 61200 (25p), see packs or lucozade.co.uk/wrc & itcould be u!	1
+4365	yes when is the appt again?	0
+4366	i donno if they are scorable	0
+4367	back 2 work 2morro half term over! can u c me 2nite 4 some sexy passion b4 i have 2 go back? chat now 09099726481 luv dena calls ??1/minmobsmorelkpobox177hp51fl	1
+4368	dont you have message offer	0
+4369	that means you got an a in epi, she.s fine. she.s here now.	0
+4370	when are you going to ride your bike?	0
+4371	dude u knw also telugu..thts gud..k, gud nyt..	0
+4372	oh k. . i will come tomorrow	0
+4373	in the simpsons movie released in july 2007 name the band that died at the start of the film? a-green day, b-blue day, c-red day. (send a, b or c)	1
+4374	what time should i tell my friend to be around?	0
+4375	ha... u jus ate honey ar? so sweet...	0
+4376	i have printed it oh. so  &lt;#&gt;  come upstairs	0
+4377	7 wonders in my world 7th you 6th ur style 5th ur smile 4th ur personality 3rd ur nature 2nd ur sms and 1st \ur lovely friendship\"... good morning dear"	0
+4378	i love working from home :)	0
+4379	friendship is not a game to play, it is not a word to say, it doesn\'t start on march and ends on may, it is tomorrow, yesterday, today and e	0
+4380	then u better go sleep.. dun disturb u liao.. u wake up then msg me lor..	0
+4381	change windows logoff sound..	0
+4382	sorry, i'll call later	0
+4383	yup... ok i go home look at the timings then i msg ?_ again... xuhui going to learn on 2nd may too but her lesson is at 8am	0
+4384	hi, this is mandy sullivan calling from hotmix fm...you are chosen to receive ??5000.00 in our easter prize draw.....please telephone 09041940223 to claim before 29/03/05 or your prize will be transferred to someone else....	1
+4385	hiya do u like the hlday pics looked horrible in them so took mo out! hows the camp amrca thing? speak soon serena:)	0
+4386	wat time u finish ur lect today?	0
+4387	u sleeping now.. or you going to take? haha.. i got spys wat.. me online checking n replying mails lor..	0
+4388	filthy stories and girls waiting for your	1
+4389	sorry chikku, my cell got some problem thts y i was nt able to reply u or msg u..	0
+4390	can you call me plz. your number shows out of coveragd area. i have urgnt call in vasai &amp; have to reach before 4'o clock so call me plz	0
+4391	free entry in 2 a weekly comp for a chance to win an ipod. txt pod to 80182 to get entry (std txt rate) t&c's apply 08452810073 for details 18+	1
+4392	to day class is there are no class.	0
+4393	good night my dear.. sleepwell&amp;take care	0
+4394	how much you got for cleaning	0
+4395	wewa is 130. iriver 255. all 128 mb.	0
+4396	i'm still pretty weak today .. bad day ?	0
+4397	no pic. please re-send.	0
+4398	had your mobile 11mths ? update for free to oranges latest colour camera mobiles & unlimited weekend calls. call mobile upd8 on freefone 08000839402 or 2stoptx	1
+4399	k. did you call me just now ah?	0
+4400	ok . . now i am in bus. . if i come soon i will come otherwise tomorrow	0
+4401	this weekend is fine (an excuse not to do too much decorating)	0
+4402	hey... what time is your driving on fri? we go for evaluation on fri?	0
+4403	all was well until slightly disastrous class this pm with my fav darlings! hope day off ok. coffee wld be good as can't stay late tomorrow. same time + place as always?	0
+4404	well i'm going to be an aunty!	0
+4405	not to worry. i'm sure you'll get it.	0
+4406	i'm coming back on thursday. yay. is it gonna be ok to get the money. cheers. oh yeah and how are you. everything alright. hows school. or do you call it work now	0
+4407	can you use foreign stamps for whatever you send them off for?	0
+4408	oh wow thats gay. will firmware update help	0
+4409	asking do u knw them or nt? may be ur frnds or classmates?	0
+4410	87077: kick off a new season with 2wks free goals & news to ur mobile! txt ur club name to 87077 eg villa to 87077	1
+4411	dear we are going to our rubber place	0
+4412	yes, i'm small kid.. and boost is the secret of my energy..	0
+4413	'an amazing quote'' - \sometimes in life its difficult to decide whats wrong!! a lie that brings a smile or the truth that brings a tear....\""	0
+4414	cps is causing the outages to conserve energy.	0
+4415	bangbabes ur order is on the way. u should receive a service msg 2 download ur content. if u do not, goto wap. bangb. tv on ur mobile internet/service menu	1
+4416	still in customer place	0
+4417	i don know account details..i will ask my mom and send you.my mom is out of reach now.	0
+4418	helloooo... wake up..! \sweet\" \"morning\" \"welcomes\" \"you\" \"enjoy\" \"this day\" \"with full of joy\".. \"gud mrng\"."	0
+4419	yes baby! i need to stretch open your pussy!	0
+4420	should i buy him a blackberry bold 2 or torch. should i buy him new or used. let me know. plus are you saying i should buy the  &lt;#&gt; g wifi ipad. and what are you saying about the about the  &lt;#&gt; g?	0
+4421	hey leave it. not a big deal:-) take care.	0
+4422	nope i waiting in sch 4 daddy...	0
+4423	lol that's different. i don't go trying to find every real life photo you ever took.	0
+4424	okey dokey, i???ll be over in a bit just sorting some stuff out.	0
+4425	important information 4 orange user . today is your lucky day!2find out why log onto http://www.urawinner.com there's a fantastic surprise awaiting you!	1
+4426	ok omw now, you at castor?	0
+4427	any chance you might have had with me evaporated as soon as you violated my privacy by stealing my phone number from your employer's paperwork. not cool at all. please do not contact me again or i will report you to your supervisor.	0
+4428	what do u want for xmas? how about 100 free text messages & a new video phone with half price line rental? call free now on 0800 0721072 to find out more!	1
+4429	i've reached home n i bathe liao... u can call me now...	0
+4430	lets use it next week, princess :)	0
+4431	you do got a shitload of diamonds though	0
+4432	so is th gower mate which is where i am!?! how r u man? all is good in wales ill b back ??morrow. c u this wk? who was the msg 4? ?? random!	0
+4433	sounds like there could be a lot of time spent in that chastity device boy ... *grins* ... or take your beatings like a good dog. going to lounge in a nice long bath now ?	0
+4434	camera - you are awarded a sipix digital camera! call 09061221066 fromm landline. delivery within 28 days.	1
+4435	we tried to contact you re your reply to our offer of 750 mins 150 textand a new video phone call 08002988890 now or reply for free delivery tomorrow	1
+4436	good morning my dear shijutta........... have a great &amp; successful day.	0
+4437	i couldn't say no as he is a dying man and i feel sad for him so i will go and i just wanted you to know i would probably be gone late into your night	0
+4438	new tones this week include: 1)mcfly-all ab.., 2) sara jorge-shock.. 3) will smith-switch.. to order follow instructions on next message	1
+4439	hmm yeah if your not too grooved out! and im looking forward to my pound special :)	0
+4440	happy new year princess!	0
+4441	no idea, i guess we'll work that out an hour after we're supposed to leave since as usual nobody has any interest in figuring shit out before the last second	0
+4442	sure, but make sure he knows we ain't smokin yet	0
+4443	don't forget though that i love you .... and i walk beside you. watching over you and keeping your heart warm.	0
+4444	ok cool. see ya then.	0
+4445	especially since i talk about boston all up in my personal statement, lol! i woulda changed that if i had realized it said nyc! it says boston now.	0
+4446	for ur chance to win a ??250 cash every wk txt: action to 80608. t's&c's www.movietrivia.tv custcare 08712405022, 1x150p/wk.	1
+4447	can you plz tell me the ans. bslvyl sent via fullonsms.com	0
+4448	what happened in interview?	0
+4449	no, but you told me you were going, before you got drunk!	0
+4450	u buy newspapers already?	0
+4451	erm. i thought the contract ran out the4th of october.	0
+4452	yup	0
+4453	got meh... when?	0
+4454	its good to hear from you	0
+4455	i had askd u a question some hours before. its answer	0
+4456	have you heard from this week?	0
+4457	r ?_ going 4 today's meeting?	0
+4458	gud mrng dear have a nice day	0
+4459	its ok., i just askd did u knw tht no?	0
+4460	i am late. i will be there at	0
+4461	for ur chance to win a ??250 cash every wk txt: action to 80608. t's&c's www.movietrivia.tv custcare 08712405022, 1x150p/wk	1
+4462	what happen to her tell the truth	0
+4463	sorry, i'll call later	0
+4464	even u dont get in trouble while convincing..just tel him once or twice and just tel neglect his msgs dont c and read it..just dont reply	0
+4465	where did u go? my phone is gonna die you have to stay in here	0
+4466	sorry, i'll call later	0
+4467	were trying to find a chinese food place around here	0
+4468	hmm ok, i'll stay for like an hour cos my eye is really sore!	0
+4469	goodmorning sleeping ga.	0
+4470	ee msg na poortiyagi odalebeku: hanumanji 7 name 1-hanuman 2-bajarangabali 3-maruti 4-pavanaputra 5-sankatmochan 6-ramaduth 7-mahaveer ee 7 name  &lt;#&gt;  janarige ivatte kalisidare next saturday olage ondu good news keluviri...! maretare inde 1 dodda problum nalli siguviri idu matra  &lt;#&gt; % true.. don't neglet.	0
+4471	free for 1st week! no1 nokia tone 4 ur mob every week just txt nokia to 8007 get txting and tell ur mates www.getzed.co.uk pobox 36504 w45wq norm150p/tone 16+	1
+4472	pansy! you've been living in a jungle for two years! its my driving you should be more worried about!	0
+4473	hi, wlcome back, did wonder if you got eaten by a lion or something, nothing much	0
+4474	hiya , have u been paying money into my account? if so, thanks. got a pleasant surprise when i checked my balance -u c, i don't get statements 4 that acc	0
+4475	sir, i need axis bank account no and bank address.	0
+4476	urgent! call 09066350750 from your landline. your complimentary 4* ibiza holiday or 10,000 cash await collection sae t&cs po box 434 sk3 8wp 150 ppm 18+	1
+4477	our dating service has been asked 2 contact u by someone shy! call 09058091870 now all will be revealed. pobox84, m26 3uz 150p	1
+4478	camera quite good, 10.1mega pixels, 3optical and 5digital dooms. have a lovely holiday, be safe and i hope you hav a good journey! happy new year to you both! see you in a couple of weeks!	0
+4479	i've sent ?_ my part..	0
+4480	hmmm...k...but i want to change the field quickly da:-)i wanna get system administrator or network administrator..	0
+4481	?? bot notes oredi... cos i juz rem i got...	0
+4482	just dropped em off, omw back now	0
+4483	have a lovely night and when you wake up to see this message, i hope you smile knowing all is as should be. have a great morning	0
+4484	in that case i guess i'll see you at campus lodge	0
+4485	wait.i will come out.. &lt;#&gt;  min:)	0
+4486	valentines day special! win over ??1000 in our quiz and take your partner on the trip of a lifetime! send go to 83600 now. 150p/msg rcvd. custcare:08718720201	1
+4487	she told to hr that he want posting in chennai:)because i'm working here:)	0
+4488	will do, you gonna be at blake's all night? i might be able to get out of here a little early	0
+4489	do you want a new nokia 3510i colour phone delivered tomorrow? with 200 free minutes to any mobile + 100 free text + free camcorder reply or call 8000930705	1
+4490	great! so what attracts you to the brothas?	0
+4491	december only! had your mobile 11mths+? you are entitled to update to the latest colour camera mobile for free! call the mobile update vco free on 08002986906	1
+4492	happy new year. hope you are having a good semester	0
+4493	no calls..messages..missed calls	0
+4494	aathi..where are you dear..	0
+4495	what's up bruv, hope you had a great break. do have a rewarding semester.	0
+4496	hmmm.but you should give it on one day..	0
+4497	thing r good thanx got exams in march ive done no revision? is fran still with boyf? ive gotta interviw 4 exeter bit worried!x	0
+4498	hi babe its me thanks for coming even though it didnt go that well!i just wanted my bed! hope to see you soon love and kisses xxx	0
+4499	okay name ur price as long as its legal! wen can i pick them up? y u ave x ams xx	0
+4500	urgent! please call 0906346330. your abta complimentary 4* spanish holiday or ??10,000 cash await collection sae t&cs box 47 po19 2ez 150ppm 18+	1
+4501	good morning. at the repair shop--the only reason i'm up at this hour.	0
diff --git a/data/sms/valid.tsv b/data/sms/valid.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..3ebfa4fd0cae10d36b5b8baec03af0daa3ce10d7
--- /dev/null
+++ b/data/sms/valid.tsv
@@ -0,0 +1,501 @@
+index	sentence1	label
+0	i'm reading the text i just sent you. its meant to be a joke. so read it in that light	0
+1	reply to win ??100 weekly! what professional sport does tiger woods play? send stop to 87239 to end service	1
+2	\shit babe.. thasa bit messed up.yeh  illspeak 2 u2moro wen im not asleep...\"" illspeak 2 u2moro wen im not asleep...\""	0
+3	my sister got placed in birla soft da:-)	0
+4	1) go to write msg 2) put on dictionary mode 3)cover the screen with hand, 4)press  &lt;#&gt; . 5)gently remove ur hand.. its interesting..:)	0
+5	one small prestige problem now.	0
+6	i want snow. it's just freezing and windy.	0
+7	wishing you and your family merry \x\" mas and happy new year in advance.."	0
+8	mah b, i'll pick it up tomorrow	0
+9	o. guess they both got screwd	0
+10	i think that tantrum's finished so yeah i'll be by at some point	0
+11	gud mrng dear hav a nice day	0
+12	ok. i asked for money how far	0
+13	this is all just creepy and crazy to me.	0
+14	well. balls. time to make calls	0
+15	i???ll leave around four, ok?	0
+16	urgent! we are trying to contact u. todays draw shows that you have won a ??800 prize guaranteed. call 09050003091 from land line. claim c52. valid12hrs only	1
+17	no, i decided that only people who care about stuff vote and caring about stuff is for losers	0
+18	yeah there's quite a bit left, i'll swing by tomorrow when i get up	0
+19	this message is brought to you by gmw ltd. and is not connected to the	1
+20	cool, text me when you're ready	0
+21	reading gud habit.. nan bari hudgi yorge pataistha ertini kano:-)	0
+22	playin space poker, u?	0
+23	doing project w frens lor.	0
+24	purity of friendship between two is not about smiling after reading the forwarded message..its about smiling just by seeing the name. gud evng	0
+25	just got outta class gonna go gym.	0
+26	what i meant to say is cant wait to see u again getting bored of this bridgwater banter	0
+27	urgent! your mobile number has been awarded with a ??2000 bonus caller prize. call 09058095201 from land line. valid 12hrs only	1
+28	yeah, where's your class at?	0
+29	i want to see your pretty pussy...	0
+30	double mins & 1000 txts on orange tariffs. latest motorola, sonyericsson & nokia with bluetooth free! call mobileupd8 on 08000839402 or call2optout/hf8	1
+31	the fact that you're cleaning shows you know why i'm upset. your priority is constantly \what i want to do	0
+32	when i have stuff to sell i.ll tell you	0
+33	ofcourse i also upload some songs	0
+34	yo dude guess who just got arrested the other day	0
+35	yeah get the unlimited	0
+36	honey, can you pls find out how much they sell predicte in nigeria. and how many times can it be used. its very important to have a reply before monday	0
+37	so lets make it saturday or monday as per convenience.	0
+38	k tell me anything about you.	0
+39	free msg:we billed your mobile number by mistake from shortcode 83332.please call 08081263000 to have charges refunded.this call will be free from a bt landline	1
+40	from 5 to 2 only my work timing.	0
+41	hey what how about your project. started aha da.	0
+42	yo, call me when you get the chance, a friend of mine wanted me to ask you about a big order	0
+43	i'm going out to buy mum's present ar.	0
+44	important information 4 orange user 0789xxxxxxx. today is your lucky day!2find out why log onto http://www.urawinner.com there's a fantastic surprise awaiting you!	1
+45	hmm ill have to think about it... ok you're forgiven! =d	0
+46	jolly good! by the way,  will give u tickets for sat eve 7.30. speak before then x	0
+47	hey! there's veggie pizza... :/	0
+48	its  &lt;#&gt; k here oh. should i send home for sale.	0
+49	yeah, don't go to bed, i'll be back before midnight	0
+50	watching ajith film ah?	0
+51	motivate behind every darkness, there is a shining light waiting for you to find it... behind every best friend, there is always trust and love... bslvyl	0
+52	yeah that's the impression i got	0
+53	was playng 9 doors game and gt racing on phone lol	0
+54	you have registered sinco as payee. log in at icicibank.com and enter urn  &lt;#&gt;  to confirm. beware of frauds. do not share or disclose urn to anyone.	0
+55	it is only yesterday true true.	0
+56	i know you are. can you pls open the back?	0
+57	let there be snow. let there be snow. this kind of weather brings ppl together so friendships can grow.	0
+58	dont hesitate. you know this is the second time she has had weakness like that. so keep i notebook of what she eat and did the day before or if anything changed the day before so that we can be sure its nothing	0
+59	and do you have any one that can teach me how to ship cars.	0
+60	but i'm on a diet. and i ate 1 too many slices of pizza yesterday. ugh i'm always on a diet.	0
+61	hey gorgeous man. my work mobile number is. have a good one babe. squishy mwahs.	0
+62	wife.how she knew the time of murder exactly	0
+63	urgent we are trying to contact you last weekends draw shows u have won a ??1000 prize guaranteed call 09064017295 claim code k52 valid 12hrs 150p pm	1
+64	urgent! please call 09061213237 from a landline. ??5000 cash or a 4* holiday await collection. t &cs sae po box 177 m227xy. 16+	1
+65	i will vote for wherever my heart guides me	0
+66	oh ! a half hour is much longer in syria than canada, eh ? wow you must get so much more work done in a day than us with all that extra time ! *grins*	0
+67	themob>hit the link to get a premium pink panther game, the new no. 1 from sugababes, a crazy zebra animation or a badass hoody wallpaper-all 4 free!	1
+68	hi petey!noi??m ok just wanted 2 chat coz avent spoken 2 u 4 a long time-hope ur doin alrite.have good nit at js love ya am.x	0
+69	same as u... dun wan... y u dun like me already ah... wat u doing now? still eating?	0
+70	see you then, we're all christmassy here!	0
+71	s da..al r above  &lt;#&gt;	0
+72	guessin you ain't gonna be here before 9?	0
+73	awesome, lemme know whenever you're around	0
+74	heehee that was so funny tho	0
+75	kallis wont bat in 2nd innings.	0
+76	bbq this sat at mine from 6ish. ur welcome 2 come	0
+77	v-aluable. a-ffectionate. l-oveable. e-ternal. n-oble. t-ruthful. i-ntimate. n-atural. e-namous. happy \valentines day\" in advance"	0
+78	storming msg: wen u lift d phne, u say \hello\" do u knw wt is d real meaning of hello?? . . . it's d name of a girl..! . . . yes.. and u knw who is dat girl?? \"margaret hello\" she is d girlfrnd f grahmbell who invnted telphone... . . . . moral:one can 4get d name of a person	0
+79	u definitely need a module from e humanities dis sem izzit? u wan 2 take other modules 1st?	0
+80	hey doc pls i want to get nice t shirt for my hubby nice fiting ones my budget is  &lt;#&gt; k help pls i will load d card abi hw,keep me posted luv. 2 mj	0
+81	so i could kiss and feel you next to me...	0
+82	what about this one then.	0
+83	freemsg you have been awarded a free mini digital camera, just reply snap to collect your prize! (quizclub opt out? stop 80122300p/wk sp:rwm ph:08704050406)	1
+84	ooooooh i forgot to tell u i can get on yoville on my phone	0
+85	then cant get da laptop? my matric card wif ?_ lei...	0
+86	themob> check out our newest selection of content, games, tones, gossip, babes and sport, keep your mobile fit and funky text wap to 82468	1
+87	sorry my roommates took forever, it ok if i come by now?	0
+88	call 09090900040 & listen to extreme dirty live chat going on in the office right now total privacy no one knows your [sic] listening 60p min 24/7mp 0870753331018+	1
+89	i know but you need to get hotel now. i just got my invitation but i had to apologise. cali is to sweet for me to come to some english bloke's weddin	0
+90	any way where are you and what doing.	0
+91	i've not sent it. he can send me.	0
+92	complimentary 4 star ibiza holiday or ??10,000 cash needs your urgent collection. 09066364349 now from landline not to lose out! box434sk38wp150ppm18+	1
+93	private! your 2003 account statement for 07815296484 shows 800 un-redeemed s.i.m. points. call 08718738001 identifier code 41782 expires 18/11/04	1
+94	i am real, baby! i want to bring out your inner tigress...	0
+95	wat makes u thk i'll fall down. but actually i thk i'm quite prone 2 falls. lucky my dad at home i ask him come n fetch me already.	0
+96	hi.:)technical support.providing assistance to us customer through call and email:)	0
+97	you stayin out of trouble stranger!!saw dave the other day he??s sorted now!still with me bloke when u gona get a girl mr!ur mum still thinks we will get 2getha!	0
+98	i hav almost reached. call, i m unable to connect u.	0
+99	ok i also wan 2 watch e 9 pm show...	0
+100	yo come over carlos will be here soon	0
+101	black shirt n blue jeans... i thk i c ?_...	0
+102	aight ill get on fb in a couple minutes	0
+103	?? all write or wat..	0
+104	:-( sad puppy noise	0
+105	babe ! what are you doing ? where are you ? who are you talking to ? do you think of me ? are you being a good boy? are you missing me? do you love me ?	0
+106	if we hit it off, you can move in with me :)	0
+107	should i head straight there or what	0
+108	oh all have to come ah?	0
+109	wat's my dear doing? sleeping ah?	0
+110	audrie lousy autocorrect	0
+111	sorry pa, i dont knw who ru pa?	0
+112	k..k:)how about your training process?	0
+113	enjoy the showers of possessiveness poured on u by ur loved ones, bcoz in this world of lies, it is a golden gift to be loved truly..	0
+114	we took hooch for a walk toaday and i fell over! splat! grazed my knees and everything! should have stayed at home! see you tomorrow!	0
+115	in the end she might still vomit but its okay. not everything will come out.	0
+116	i keep ten rs in my shelf:) buy two egg.	0
+117	i'm ok wif it cos i like 2 try new things. but i scared u dun like mah. cos u said not too loud.	0
+118	free msg: get gnarls barkleys \crazy\" ringtone totally free just reply go to this message right now!"	1
+119	this is the 2nd attempt to contract u, you have won this weeks top prize of either ??1000 cash or ??200 prize. just call 09066361921	1
+120	no. but we'll do medical missions to nigeria	0
+121	becoz its  &lt;#&gt;  jan whn al the post ofice is in holiday so she cn go fr the post ofice...got it duffer	0
+122	good afternoon, my boytoy ... how are you feeling today ? better i hope? are you being my good boy? are you my obedient, slave? do you please your queen?	0
+123	aight i'll grab something to eat too, text me when you're back at mu	0
+124	no she didnt. i will search online and let you know.	0
+125	how will i creep on you now? ;_;	0
+126	well done england! get the official poly ringtone or colour flag on yer mobile! text tone or flag to 84199 now! opt-out txt eng stop. box39822 w111wx ??1.50	1
+127	yeah i should be able to, i'll text you when i'm ready to meet up	0
+128	joy's father is john. then john is the name of joy's father. mandan	0
+129	themob>yo yo yo-here comes a new selection of hot downloads for our members to get for free! just click & open the next link sent to ur fone...	1
+130	derp. which is worse, a dude who always wants to party or a dude who files a complaint about the three drug abusers he lives with	0
+131	give her something to drink, if she takes it and doesn't vomit then you her temp might drop. if she unmits however let me know.	0
+132	should i be stalking u?	0
+133	dare i ask... any luck with sorting out the car?	0
+134	lyricalladie(21/f) is inviting you to be her friend. reply yes-910 or no-910. see her: www.sms.ac/u/hmmross stop? send stop frnd to 62468	1
+135	joy's father is john. then john is the ____ of joy's father. if u ans ths you hav  &lt;#&gt;  iq. tis s ias question try to answer.	0
+136	i will reach ur home in  &lt;#&gt;  minutes	0
+137	true lov n care wil nevr go unrecognized. though somone often makes mistakes when valuing it. but they will definitly undrstnd once when they start missing it.	0
+138	lol no. i just need to cash in my nitros. hurry come on before i crash out!	0
+139	happy new year to u and ur family...may this new year bring happiness , stability and tranquility to ur vibrant colourful life:):)	0
+140	you have won a guaranteed ??1000 cash or a ??2000 prize. to claim yr prize call our customer service representative on 08714712379 between 10am-7pm cost 10p	1
+141	no da..today also i forgot..	0
+142	true dear..i sat to pray evening and felt so.so i sms'd you in some time...	0
+143	wat time do u wan 2 meet me later?	0
+144	u don't remember that old commercial?	0
+145	babe !!!! i love you !!!! *covers your face in kisses*	0
+146	\petey boy whereare you me and all your friendsare in thekingshead come down if you canlove nic\""	0
+147	this girl does not stay in bed. this girl doesn't need recovery time. id rather pass out while having fun then be cooped up in bed	0
+148	and stop being an old man. you get to build snowman snow angels and snowball fights.	0
+149	you are chosen to receive a ??350 award! pls call claim number 09066364311 to collect your award which you are selected to receive as a valued mobile customer.	1
+150	you have won a guaranteed ??1000 cash or a ??2000 prize.to claim yr prize call our customer service representative on	1
+151	haha awesome, i might need to take you up on that, what you doin tonight?	0
+152	he like not v shock leh. cos telling shuhui is like telling leona also. like dat almost all know liao. he got ask me abt ur reaction lor.	0
+153	good afternoon, my love. how goes your day ? what are you up to ? i woke early and am online waiting for you ... hmmm ... italian boy is online i see . *grins*	0
+154	no..its ful of song lyrics..	0
+155	got smaller capacity one? quite ex...	0
+156	howz pain?hope u r fine..	0
+157	warner village 83118 c colin farrell in swat this wkend @warner village & get 1 free med. popcorn!just show msg+ticket@kiosk.valid 4-7/12. c t&c @kiosk. reply sony 4 mre film offers	1
+158	u come n search tat vid..not finishd..	0
+159	yo, i'm at my parents' gettin cash. good news: we picked up a downstem	0
+160	the table's occupied, i'm waiting by the tree	0
+161	jay's getting really impatient and belligerent	0
+162	&lt;#&gt; , that's all? guess that's easy enough	0
+163	oh ok wait 4 me there... my lect havent finish	0
+164	see? i thought it all through	0
+165	sorry sent blank msg again. yup but trying 2 do some serious studying now.	0
+166	actually nvm, got hella cash, we still on for  &lt;#&gt; ish?	0
+167	good afternoon loverboy ! how goes you day ? any luck come your way? i think of you, sweetie and send my love across the sea to make you smile and happy	0
+168	to review and keep the fantastic nokia n-gage game deck with club nokia, go 2 www.cnupdates.com/newsletter. unsubscribe from alerts reply with the word out	1
+169	also that chat was awesome but don't make it regular unless you can see her in person	0
+170	k, wait chikku..il send aftr  &lt;#&gt; mins	0
+171	baaaaaaaabe! wake up ! i miss you ! i crave you! i need you!	0
+172	and i don't plan on staying the night but i prolly won't be back til late	0
+173	er yep sure. props?	0
+174	you flippin your shit yet?	0
+175	sorry, i'll call later	0
+176	im realy soz imat my mums 2nite what about 2moro	0
+177	private! your 2003 account statement for 07973788240 shows 800 un-redeemed s. i. m. points. call 08715203649 identifier code: 40533 expires 31/10/04	1
+178	aight, see you in a bit	0
+179	guess who am i?this is the first time i created a web page www.asjesus.com read all i wrote. i'm waiting for your opinions. i want to be your friend 1/1	1
+180	urgent! please call 09061213237 from landline. ??5000 cash or a luxury 4* canary islands holiday await collection. t&cs sae po box 177. m227xy. 150ppm. 16+	1
+181	no i am not having not any movies in my laptop	0
+182	how do you guys go to see movies on your side.	0
+183	think ur smart ? win ??200 this week in our weekly quiz, text play to 85222 now!t&cs winnersclub po box 84, m26 3uz. 16+. gbp1.50/week	1
+184	do u want 2 meet up 2morro	0
+185	we stopped to get ice cream and will go back after	0
+186	beautiful truth against gravity.. read carefully: \our heart feels light when someone is in it.. but it feels very heavy when someone leaves it..\" goodmorning"	0
+187	ok.	0
+188	urgent ur ??500 guaranteed award is still unclaimed! call 09066368327 now closingdate04/09/02 claimcode m39m51 ??1.50pmmorefrommobile2bremoved-mobypobox734ls27yf	1
+189	that???s the thing with apes, u can fight to the death to keep something, but the minute they have it when u let go, thats it!	0
+190	missed your call cause i was yelling at scrappy. miss u. can't wait for u to come home. i'm so lonely today.	0
+191	\can i please come up now imin town.dontmatter if urgoin outl8r u no thecd isv.important tome 4 2moro\""	0
+192	oh k. . i will come tomorrow	0
+193	jay told me already, will do	0
+194	call freephone 0800 542 0578 now!	1
+195	text her. if she doesnt reply let me know so i can have her log in	0
+196	oooh i got plenty of those!	0
+197	i'll probably be around mu a lot	0
+198	good sleep is about rhythm. the person has to establish a rhythm that the body will learn and use. if you want to know more :-)	0
+199	7 at esplanade.. do ?_ mind giving me a lift cos i got no car today..	0
+200	can u get 2 phone now? i wanna chat 2 set up meet call me now on 09096102316 u can cum here 2moro luv jane xx calls??1/minmoremobsemspobox45po139wa	1
+201	\happy valentines day\" i know its early but i have hundreds of handsomes and beauties to wish. so i thought to finish off aunties and uncles 1st..."	0
+202	dunno i juz askin cos i got a card got 20% off 4 a salon called hair sense so i tot it's da one ?_ cut ur hair.	0
+203	yeah my usual guy's out of town but there're definitely people around i know	0
+204	u have a secret admirer who is looking 2 make contact with u-find out who they r*reveal who thinks ur so special-call on 09065171142-stopsms-08718727870150ppm	1
+205	happy new year my no.1 man	0
+206	ur ringtone service has changed! 25 free credits! go to club4mobiles.com to choose content now! stop? txt club stop to 87070. 150p/wk club4 po box1146 mk45 2wt	1
+207	what???? hello wats talks email address?	0
+208	height of \oh shit....!!\" situation: a guy throws a luv letter on a gal but falls on her brothers head whos a gay	0
+209	tunji, how's the queen? how are you doing. this is just wishing you a great day. abiola.	0
+210	:-) :-)	0
+211	i cant pick the phone right now. pls send a message	0
+212	dear relieved of westonzoyland, all going to plan this end too!	0
+213	how long does it take to get it.	0
+214	hmmm ... and imagine after you've come home from that having to rub my feet, make me dinner and help me get ready for my date ! are you sure your ready for that kind of life ?	0
+215	neva mind it's ok..	0
+216	is there coming friday is leave for pongal?do you get any news from your work place.	0
+217	hows my favourite person today? r u workin hard? couldn't sleep again last nite nearly rang u at 4.30	0
+218	ya that one is slow as poo	0
+219	ard 530 lor. i ok then message ?_ lor.	0
+220	u can call now...	0
+221	no management puzzeles.	0
+222	there's someone here that has a year  &lt;#&gt;  toyota camry like mr olayiwola's own. mileage is  &lt;#&gt; k.its clean but i need to know how much will it sell for. if i can raise the dough for it how soon after landing will it sell. holla back.	0
+223	ok..	0
+224	its a site to simulate the test. it just gives you very tough questions to test your readiness.	0
+225	tap & spile at seven. * is that pub on gas st off broad st by canal. ok?	0
+226	u r the most beautiful girl ive ever seen. u r my baby come and c me in the common room	0
+227	you busy or can i come by at some point and figure out what we're doing tomorrow	0
+228	haven't found a way to get another app for your phone, eh ? will you go to the net cafe ? did you take that job? geeee i need you babe. i crave to see you ...	0
+229	ok. i am a gentleman and will treat you with dignity and respect.	0
+230	done it but internet connection v slow and can???t send it. will try again later or first thing tomo.	0
+231	i sent you  &lt;#&gt;  bucks	0
+232	free msg: single? find a partner in your area! 1000s of real people are waiting to chat now!send chat to 62220cncl send stopcs 08717890890??1.50 per msg	1
+233	the 2 oz guy is being kinda flaky but one friend is interested in picking up $ &lt;#&gt;  worth tonight if possible	0
+234	wan2 win a meet+greet with westlife 4 u or a m8? they are currently on what tour? 1)unbreakable, 2)untamed, 3)unkempt. text 1,2 or 3 to 83049. cost 50p +std text	1
+235	can... i'm free...	0
+236	do ?_ all wan 2 meet up n combine all the parts? how's da rest of da project going?	0
+237	im fine babes aint been up 2 much tho! saw scary movie yest its quite funny! want 2mrw afternoon? at town or mall or sumthin?xx	0
+238	your unique user id is 1172. for removal send stop to 87239 customer services 08708034412	1
+239	r u meeting da ge at nite tmr?	0
+240	get ready to moan and scream :)	0
+241	so many people seems to be special at first sight, but only very few will remain special to you till your last sight.. maintain them till life ends.. sh!jas	0
+242	hmmm ... i thought we said 2 hours slave, not 3 ... you are late ... how should i punish you ?	0
+243	call him and say you not coming today ok and tell them not to fool me like this ok	0
+244	had your mobile 11mths ? update for free to oranges latest colour camera mobiles & unlimited weekend calls. call mobile upd8 on freefone 08000839402 or 2stoptxt	1
+245	wen did you get so spiritual and deep. that's great	0
+246	xmas offer! latest motorola, sonyericsson & nokia & free bluetooth or dvd! double mins & 1000 txt on orange. call mobileupd8 on 08000839402 or call2optout/4qf2	1
+247	cbe is really good nowadays:)lot of shop and showrooms:)city is shaping good.	0
+248	carry on not disturbing both of you	0
+249	you only hate me. you can call any but you didnt accept even a single call of mine. or even you messaged	0
+250	ok going to sleep. hope i can meet her.	0
+251	nope i'm not drivin... i neva develop da photos lei...	0
+252	do u think that any girl will propose u today by seing ur bloody funky shit fucking face...............asssssholeeee................	0
+253	dont show yourself. how far. put new pictures up on facebook.	0
+254	k...k:)why cant you come here and search job:)	0
+255	i'm home...	0
+256	k. i will sent it again	0
+257	tired. i haven't slept well the past few nights.	0
+258	okay... i booked all already... including the one at bugis.	0
+259	;-( oh well, c u later	0
+260	did he say how fantastic i am by any chance, or anything need a bigger life lift as losing the will 2 live, do you think i would be the first person 2 die from n v q?	0
+261	he telling not to tell any one. if so treat for me hi hi hi	0
+262	k:)eng rocking in ashes:)	0
+263	cant believe i said so many things to you this morning when all i really wanted to say was good morning, i love you! have a beautiful morning. see you in the library later.	0
+264	will be out of class in a few hours. sorry	0
+265	ya it came a while ago	0
+266	babe !!! i miiiiiiissssssssss you ! i need you !!! i crave you !!! :-( ... geeee ... i'm so sad without you babe ... i love you ...	0
+267	ur tonexs subscription has been renewed and you have been charged ??4.50. you can choose 10 more polys this month. www.clubzed.co.uk *billing msg*	1
+268	stop the story. i've told him i've returned it and he's saying i should not re order it.	0
+269	ugh hopefully the asus ppl dont randomly do a reformat.	0
+270	so i'm doing a list of buyers.	0
+271	beautiful truth against gravity.. read carefully: \our heart feels light when someone is in it.. but it feels very heavy when someone leaves it..\" good night"	0
+272	i'll text now! all creepy like so he won't think that we forgot	0
+273	if i was i wasn't paying attention	0
+274	okies... i'll go yan jiu too... we can skip ard oso, go cine den go mrt one, blah blah blah...	0
+275	then wat r u doing now? busy wif work?	0
+276	i can probably come by, everybody's done around  &lt;#&gt;  right?	0
+277	he is a womdarfull actor	0
+278	hope you are having a good week. just checking in	0
+279	well i wasn't available as i washob nobbing with last night so they had to ask nickey platt instead of me!;	0
+280	in e msg jus now. u said thanks for gift.	0
+281	i got lousy sleep. i kept waking up every 2 hours to see if my cat wanted to come in. i worry about him when its cold :(	0
+282	awesome, that gonna be soon or later tonight?	0
+283	he also knows about lunch menu only da. . i know	0
+284	i'm done. i'm sorry. i hope your next space gives you everything you want. remember all the furniture is yours. if i'm not around when you move it, just lock all the locks and leave the key with jenne.	0
+285	what i'm saying is if you haven't explicitly told nora i know someone i'm probably just not gonna bother	0
+286	best msg: it's hard to be with a person, when u know that one more step foward will make u fall in love.. &amp; one step back can ruin ur friendship.. good night:-) ...	0
+287	i wish u were here. i feel so alone	0
+288	ur cash-balance is currently 500 pounds - to maximize ur cash-in now send cash to 86688 only 150p/msg. cc: 08708800282 hg/suite342/2lands row/w1j6hl	1
+289	so many people seems to be special at first sight, but only very few will remain special to you till your last sight.. maintain them till life ends.. sh!jas	0
+290	hcl chennai requires freshers for voice process.excellent english needed.salary upto  &lt;#&gt; .call ms.suman  &lt;#&gt;  for telephonic interview -via indyarocks.com	0
+291	lol yes. our friendship is hanging on a thread cause u won't buy stuff.	0
+292	i gotta collect da car at 6 lei.	0
+293	gud mrng dear hav a nice day	0
+294	sure, whenever you show the fuck up &gt;:(	0
+295	loosu go to hospital. de dont let it careless.	0
+296	r u over scratching it?	0
+297	by the way, 'rencontre' is to meet again. mountains dont....	0
+298	good morning, im suffering from fever and dysentry ..will not be able to come to office today.	0
+299	our prasanth ettans mother passed away last night. just pray for her and family.	0
+300	alright took the morphine. back in yo.	0
+301	gud gud..k, chikku tke care.. sleep well gud nyt	0
+302	at what time should i come tomorrow	0
+303	sorry i cant take your call right now. it so happens that there r 2waxsto do wat you want. she can come and ill get her medical insurance. and she'll be able to deliver and have basic care. i'm currently shopping for the right medical insurance for her. so just give me til friday morning. thats when i.ll see the major person that can guide me to the right insurance.	0
+304	ya very nice. . .be ready on thursday	0
+305	check out choose your babe videos @ sms.shsex.netun fgkslpopw fgkslpo	1
+306	what's your room number again? wanna make sure i'm knocking on the right door	0
+307	that's fine, i'll bitch at you about it later then	0
+308	sorry. you never hear unless you book it. one was kinda a joke--thet were really looking for skinny white girls. the other was one line--you can only do so much on camera with that. something like that they're casting on the look.	0
+309	in xam hall boy asked girl tell me the starting term for dis answer i can den manage on my own after lot of hesitation n lookin around silently she said the! intha ponnungale ipaditan;)	0
+310	hey what's up charles sorry about the late reply.	0
+311	we have sent jd for customer service cum accounts executive to ur mail id, for details contact us	0
+312	645	0
+313	a boy loved a gal. he propsd bt she didnt mind. he gv lv lttrs, bt her frnds threw thm. again d boy decided 2 aproach d gal , dt time a truck was speeding towards d gal. wn it was about 2 hit d girl,d boy ran like hell n saved her. she asked 'hw cn u run so fast?' d boy replied \boost is d secret of my energy\" n instantly d girl shouted \"our energy\" n thy lived happily 2gthr drinking boost evrydy moral of d story:- i hv free msgs:d;): gud ni8"	0
+314	hurt me... tease me... make me cry... but in the end of my life when i die plz keep one rose on my grave and say stupid i miss u.. have a nice day bslvyl	0
+315	we are pleased to inform that your application for airtel broadband is processed successfully. your installation will happen within 3 days.	0
+316	ok...	0
+317	had your mobile 11 months or more? u r entitled to update to the latest colour mobiles with camera for free! call the mobile update co free on 08002986030	1
+318	i have no money 4 steve mate! !	0
+319	we are supposed to meet to discuss abt our trip... thought xuhui told you? in the afternoon. thought we can go for lesson after that	0
+320	yeah so basically any time next week you can get away from your mom &amp; get up before 3	0
+321	ur cash-balance is currently 500 pounds - to maximize ur cash-in now send go to 86688 only 150p/msg. cc: 08718720201 po box 114/14 tcr/w1	1
+322	i send the print  outs da.	0
+323	free 1st week entry 2 textpod 4 a chance 2 win 40gb ipod or ??250 cash every wk. txt vpod to 81303 ts&cs www.textpod.net custcare 08712405020.	1
+324	hows the pain dear?y r u smiling?	0
+325	gonna let me know cos comes bak from holiday that day.  is coming. don't4get2text me  number.	0
+326	is xy going 4 e lunch?	0
+327	nobody can decide where to eat and dad wants chinese	0
+328	?? dun need to pick ur gf?	0
+329	yeah you should. i think you can use your gt atm now to register. not sure but if there's anyway i can help let me know. but when you do be sure you are ready.	0
+330	as a valued customer, i am pleased to advise you that following recent review of your mob no. you are awarded with a ??1500 bonus prize, call 09066364589	1
+331	i'm meeting darren...	0
+332	aaooooright are you at work?	0
+333	i think steyn surely get one wicket:)	0
+334	hurry home. soup is done!	0
+335	money!!! you r a lucky winner ! 2 claim your prize text money 2 88600 over ??1million to give away ! ppt150x3+normal text rate box403 w1t1jy	1
+336	auntie huai juan never pick up her phone	0
+337	sorry, i'll call later in meeting	0
+338	shuhui has bought ron's present it's a swatch watch...	0
+339	forgot you were working today! wanna chat, but things are ok so drop me a text when you're free / bored etc and i'll ring. hope all is well, nose essay and all xx	0
+340	i think it's all still in my car	0
+341	cool, text me when you head out	0
+342	i wanna watch that movie	0
+343	aslamalaikkum....insha allah tohar beeen muht albi mufti mahfuuz...meaning same here....	0
+344	he is world famamus....	0
+345	hello, hello, hi lou sorry it took so long 2 reply- i left mobile at friends in lancaster, just got it bak neway im sorry i couldn??t make ur b??day 2 hun!	0
+346	thank you princess! you are so sexy...	0
+347	so your telling me i coulda been your real valentine and i wasn't? u never pick me for nothing!!	0
+348	hey so this sat are we going for the intro pilates only? or the kickboxing too?	0
+349	so u gonna get deus ex?	0
+350	those cocksuckers. if it makes you feel better ipads are worthless garbage novelty items and you should feel bad for even wanting one	0
+351	sorry i've not gone to that place. i.ll do so tomorrow. really sorry.	0
+352	mmmmm ... i loved waking to your words this morning ! i miss you too, my love. i hope your day goes well and you are happy. i wait for us to be together again	0
+353	night has ended for another day, morning has come in a special way. may you smile like the sunny rays and leaves your worries at the blue blue bay.	0
+354	i will come with karnan car. please wait till 6pm will directly goto doctor.	0
+355	no need to say anything to me. i know i am an outsider	0
+356	are you happy baby ? are you alright ? did you take that job ? i hope your fine. i send you a kiss to make you smile from across the sea ... *kiss* *kiss*	0
+357	my sis is catching e show in e afternoon so i'm not watching w her. so c u wan 2 watch today or tmr lor.	0
+358	hey no i ad a crap nite was borin without ya 2 boggy with me u boring biatch! thanx but u wait til nxt time il ave ya	0
+359	quite late lar... ard 12 anyway i wun b drivin...	0
+360	does cinema plus drink appeal tomo? * is a fr thriller by director i like on at mac at 8.30.	0
+361	&lt;#&gt;  mins but i had to stop somewhere first.	0
+362	ok lor. msg me b4 u call.	0
+363	tunde, how are you doing. this is just wishing you a great day. abiola.	0
+364	hey i've booked the pilates and yoga lesson already... haha	0
+365	ok. no wahala. just remember that a friend in need ...	0
+366	private! your 2003 account statement for 07808 xxxxxx shows 800 un-redeemed s. i. m. points. call 08719899217 identifier code: 41685 expires 07/11/04	1
+367	ok i'm waliking ard now... do u wan me 2 buy anything go ur house?	0
+368	dear 0776xxxxxxx u've been invited to xchat. this is our final attempt to contact u! txt chat to 86688 150p/msgrcvdhg/suite342/2lands/row/w1j6hl ldn 18yrs	1
+369	wan2 win a meet+greet with westlife 4 u or a m8? they are currently on what tour? 1)unbreakable, 2)untamed, 3)unkempt. text 1,2 or 3 to 83049. cost 50p +std text	1
+370	you are a ??1000 winner or guaranteed caller prize, this is our final attempt to contact you! to claim call 09071517866 now! 150ppmpobox10183bhamb64xe	1
+371	5 nights...we nt staying at port step liao...too ex	0
+372	no..but heard abt tat..	0
+373	six chances to win cash! from 100 to 20,000 pounds txt> csh11 and send to 87575. cost 150p/day, 6days, 16+ tsandcs apply reply hl 4 info	1
+374	nope thats fine. i might have a nap tho!	0
+375	k.k.how is your business now?	0
+376	are you up for the challenge? i know i am :)	0
+377	urgent! your mobile number has been awarded with a ??2000 prize guaranteed. call 09061790126 from land line. claim 3030. valid 12hrs only 150ppm	1
+378	i've been trying to reach him without success	0
+379	hello which the site to download songs its urgent pls	0
+380	quite lor. but dun tell him wait he get complacent...	0
+381	what pa tell me.. i went to bath:-)	0
+382	we spend our days waiting for the ideal path to appear in front of us.. but what we forget is.. \paths are made by walking.. not by waiting..\" goodnight!"	0
+383	as i entered my cabin my pa said, '' happy b'day boss !!''. i felt special. she askd me 4 lunch. after lunch she invited me to her apartment. we went there.	0
+384	congratulations ur awarded either ??500 of cd gift vouchers & free entry 2 our ??100 weekly draw txt music to 87066 tncs www.ldew.com 1 win150ppmx3age16	1
+385	true. it is passable. and if you get a high score and apply for phd, you get 5years of salary. so it makes life easier.	0
+386	ok.	0
+387	jus finish watching tv... u?	0
+388	you all ready for * big day tomorrow?	0
+389	the evo. i just had to download flash. jealous?	0
+390	k:)all the best:)congrats...	0
+391	stupid auto correct on my phone	0
+392	you got called a tool?	0
+393	good luck! draw takes place 28th feb 06. good luck! for removal send stop to 87239 customer services 08708034412	1
+394	hello, my love. what are you doing? did you get to that interview today? are you you happy? are you being a good boy? do you think of me?are you missing me ?	0
+395	yeah but which is worse for i	0
+396	watching cartoon, listening music &amp; at eve had to go temple &amp; church.. what about u?	0
+397	i know you mood off today	0
+398	yeah why not, is the gang all ready	0
+399	am not working but am up to eyes in philosophy so will text u later when a bit more free for chat...	0
+400	full heat pa:-) i have applyed oil pa.	0
+401	super da:)good replacement for murali	0
+402	okay... we wait ah	0
+403	whore you are unbelievable.	0
+404	jay says that you're a double-faggot	0
+405	they have a thread on the wishlist section of the forums where ppl post nitro requests. start from the last page and collect from the bottom up.	0
+406	dip's cell dead. so i m coming with him. u better respond else we shall come back.	0
+407	thnx dude. u guys out 2nite?	0
+408	yes. rent is very expensive so its the way we save.	0
+409	so when you gonna get rimac access	0
+410	are you willing to go for aptitude class.	0
+411	pick ur fone up now u dumb?	0
+412	am also doing in cbe only. but have to pay.	0
+413	oh! shit, i thought that was your trip! loooooool ... that just makes so much more sense now ... *grins* and the sofa reference was ... the \sleep on a couch\" link you sent me ... wasn't that how you went on your trip ? oh ... and didn't your babe go with you for that celebration with your rents?"	0
+414	i'm not sure, i was just checking out what was happening around the area	0
+415	ok, be careful ! don't text and drive !	0
+416	sunshine quiz wkly q! win a top sony dvd player if u know which country liverpool played in mid week? txt ansr to 82277. ??1.50 sp:tyrone	1
+417	id have to check but there's only like 1 bowls worth left	0
+418	ah poop. looks like ill prob have to send in my laptop to get fixed cuz it has a gpu problem	0
+419	hi ya babe x u 4goten bout me?' scammers getting smart..though this is a regular vodafone no, if you respond you get further prem rate msg/subscription. other nos used also. beware!	1
+420	aight i've been set free, think you could text me blake's address? it occurs to me i'm not quite as sure what i'm doing as i thought i was	0
+421	k, makes sense, btw carlos is being difficult so you guys are gonna smoke while i go pick up the second batch and get gas	0
+422	nowadays people are notixiquating the laxinorficated opportunity for bambling of entropication.... have you ever oblisingately opted ur books for the masteriastering amplikater of fidalfication? it is very champlaxigating, i think it is atrocious.. wotz ur opinion???? junna	0
+423	me hungry buy some food good lei... but mum n yun dun wan juz buy a little bit...	0
+424	let me know if you need anything else. salad or desert or something... how many beers shall i get?	0
+425	s.s:)i thinl role is like sachin.just standing. others have to hit.	0
+426	i dont know oh. hopefully this month.	0
+427	first has she gained more than  &lt;#&gt; kg since she took in. second has she done the blood sugar tests. if she has and its ok and her blood pressure is within normal limits then no worries	0
+428	really? i crashed out cuddled on my sofa.	0
+429	free top ringtone -sub to weekly ringtone-get 1st week free-send subpoly to 81618-?3 per week-stop sms-08718727870	1
+430	hello, yeah i've just got out of the bath and need to do my hair so i'll come up when i'm done, yeah?	0
+431	can ?_ call me at 10:10 to make sure dat i've woken up...	0
+432	i plane to give on this month end.	0
+433	i'm gonna rip out my uterus.	0
+434	it will stop on itself. i however suggest she stays with someone that will be able to give ors for every stool.	0
+435	come by our room at some point so we can iron out the plan for this weekend	0
+436	twenty past five he said will this train have been to durham already or not coz i am in a reserved seat	0
+437	ahhh. work. i vaguely remember that! what does it feel like? lol	0
+438	sms auction - a brand new nokia 7250 is up 4 auction today! auction is free 2 join & take part! txt nokia to 86021 now! hg/suite342/2lands row/w1j6hl	1
+439	dude sux for snake. he got old and raiden got buff	0
+440	that would be great. we'll be at the guild. could meet on bristol road or somewhere - will get in touch over weekend. our plans take flight! have a good week	0
+441	no plans yet. what are you doing ?	0
+442	yup... how ?_ noe leh...	0
+443	in life when you face choices just toss a coin not becoz its settle the question but while the coin in the air u will know what your heart is hoping for. gudni8	0
+444	remember to ask alex about his pizza	0
+445	it's not that you make me cry. it's just that when all our stuff happens on top of everything else, it pushes me over the edge. you don't underdtand how often i cry over my sorry, sorry life.	0
+446	i'm aight. wat's happening on your side.	0
+447	what u talking bout early morning? it's almost noon where your at!	0
+448	text pass to 69669 to collect your polyphonic ringtones. normal gprs charges apply only. enjoy your tones	1
+449	lol enjoy role playing much?	0
+450	i am waiting machan. call me once you free.	0
+451	do you want a new video handset? 750 any time any network mins? unlimited text? camcorder? reply or call now 08000930705 for del sat am	1
+452	great. so should i send you my account number.	0
+453	then. you are eldest know.	0
+454	were somewhere on fredericksburg	0
+455	k, wat s tht incident?	0
+456	also are you bringing galileo or dobby	0
+457	nationwide auto centre (or something like that) on newport road. i liked them there	0
+458	sorry da thangam.it's my mistake.	0
+459	yes there were many sweets	0
+460	i love u 2 babe! r u sure everything is alrite. is he being an idiot? txt bak girlie	0
+461	mon okie lor... haha, best is cheap n gd food la, ex oso okie... depends on whether wana eat western or chinese food... den which u prefer...	0
+462	its going good...no problem..but still need little experience to understand american customer voice...	0
+463	claire here am havin borin time & am now alone u wanna cum over 2nite? chat now 09099725823 hope 2 c u luv claire xx calls??1/minmoremobsemspobox45po139wa	1
+464	can. dunno wat to get 4 her...	0
+465	free for 1st week! no1 nokia tone 4 ur mobile every week just txt nokia to 8077 get txting and tell ur mates. www.getzed.co.uk pobox 36504 w45wq 16+ norm150p/tone	1
+466	r we going with the  &lt;#&gt;  bus?	0
+467	hello, as per request from  &lt;#&gt;  rs.5 has been transfered to you	0
+468	although i told u dat i'm into baig face watches now but i really like e watch u gave cos it's fr u. thanx 4 everything dat u've done today, i'm touched...	0
+469	want explicit sex in 30 secs? ring 02073162414 now! costs 20p/min gsex pobox 2667 wc1n 3xx	1
+470	hello. they are going to the village pub at 8 so either come here or there accordingly. ok?	0
+471	v skint too but fancied few bevies.waz gona go meet &othrs in spoon but jst bin watchng planet earth&sofa is v comfey; if i dont make it hav gd night	0
+472	where can download clear movies. dvd copies.	0
+473	\speak only when you feel your words are better than the silence...\" gud mrng:-)"	0
+474	prabha..i'm soryda..realy..frm heart i'm sory	0
+475	my sister in law, hope you are having a great month. just saying hey. abiola	0
+476	bring tat cd don forget	0
+477	lol or i could just starve and lose a pound by the end of the day.	0
+478	faith makes things possible,hope makes things work,love makes things beautiful,may you have all three this christmas!merry christmas!	0
+479	if you're not in my car in an hour and a half i'm going apeshit	0
+480	&lt;#&gt;  is fast approaching. so, wish u a very happy new year happy sankranti happy republic day happy valentines day happy shivratri happy ugadi happy fools day happy may day happy independence day, happy friendship,mother,father,teachers,childrens day, &amp; happy birthday 4 u. happy ganesh festival happy dasara happy diwali happy christmas  &lt;#&gt;  good mornings afternoons, evenings nights. rememberi am the first to wishing u all these...your's raj	0
+481	hi juan. im coming home on fri hey. of course i expect a welcome party and lots of presents. ill phone u when i get back. loads of love nicky x x x x x x x x x	0
+482	may i call you later pls	0
+483	exactly. anyways how far. is jide her to study or just visiting	0
+484	awww dat is sweet! we can think of something to do he he! have a nice time tonight ill probably txt u later cos im lonely :( xxx.	0
+485	i got to video tape pple type in message lor. u so free wan 2 help me? hee... cos i noe u wan 2 watch infernal affairs so ask u along. asking shuhui oso.	0
+486	ok thanx... take care then...	0
+487	sorry da. i gone mad so many pending works what to do.	0
+488	hellogorgeous, hows u? my fone was on charge lst nitw wen u texd me. hopeu ad a nice wkend as im sure u did lookin 4ward 2 c-in u 2mrw luv jaz	0
+489	sir send to group mail check it.	0
+490	sorry, i'll call later in meeting	0
+491	i wanted to wish you a happy new year and i wanted to talk to you about some legal advice to do with when gary and i split but in person. i'll make a trip to ptbo for that. i hope everything is good with you babe and i love ya :)	0
+492	gibbs unsold.mike hussey	0
+493	ok lor...	0
+494	u studying in sch or going home? anyway i'll b going 2 sch later.	0
+495	here got lots of hair dresser fr china.	0
+496	as i entered my cabin my pa said, '' happy b'day boss !!''. i felt special. she askd me 4 lunch. after lunch she invited me to her apartment. we went there.	0
+497	easy mate, * guess the quick drink was bit ambitious.	0
+498	how much she payed. suganya.	0
+499	i???ve got some salt, you can rub it in my open wounds if you like!	0
diff --git a/data/split_multiwoz_data.py b/data/split_multiwoz_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb82b5f7b3d21819c63e266ef204fe9b58609540
--- /dev/null
+++ b/data/split_multiwoz_data.py
@@ -0,0 +1,65 @@
+# coding=utf-8
+#
+# Copyright 2020-2024 Heinrich Heine University Duesseldorf
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import os
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data_dir", default=None, type=str, required=True, help="Task database.")
+    args = parser.parse_args()
+
+    with open(os.path.join(args.data_dir, "data.json")) as f:
+        data = json.load(f)
+
+    val_list_file = os.path.join(args.data_dir, "valListFile.json")
+    if not os.path.isfile(val_list_file):
+        val_list_file = os.path.join(args.data_dir, "valListFile.txt")
+    with open(val_list_file) as f:
+        val_set = f.read().splitlines()
+
+    test_list_file = os.path.join(args.data_dir, "testListFile.json")
+    if not os.path.isfile(test_list_file):
+        test_list_file = os.path.join(args.data_dir, "testListFile.txt")
+    with open(test_list_file) as f:
+        test_set = f.read().splitlines()
+
+    val = {}
+    train = {}
+    test = {}
+
+    for k, v in data.items():
+        if k in val_set:
+            val[k] = v
+        elif k in test_set:
+            test[k] = v
+        else:
+            train[k] = v
+
+    print(len(data), len(train), len(val), len(test))
+
+    with open(os.path.join(args.data_dir, "train_dials.json"), "w+") as f:
+        f.write(json.dumps(train, indent = 4))
+
+    with open(os.path.join(args.data_dir, "val_dials.json"), "w+") as f:
+        f.write(json.dumps(val, indent = 4))
+
+    with open(os.path.join(args.data_dir, "test_dials.json"), "w+") as f:
+        f.write(json.dumps(test, indent = 4))
+
+if __name__ == "__main__":
+    main()
diff --git a/data/youtube/test.tsv b/data/youtube/test.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..7d12da0fe05deed218d1a7e6db6cc4947fb97529
--- /dev/null
+++ b/data/youtube/test.tsv
@@ -0,0 +1,251 @@
+index	sentence1	label
+0	Check out this video on YouTube:	1
+1	super music	0
+2	Subscribe my channel  I RECORDING FIFA 15 GOALS WATCH NOW :D	1
+3	This song is so beauty	0
+4	SEE SOME MORE SONG OPEN GOOGLE AND TYPE Shakira GuruOfMovie	1
+5	I like shakira..	0
+6	Hi. Check out and share our songs.	1
+7	I like this song 	0
+8	Hello everyone :) I know most of you probably pass up these kind of comments, but for those who are still reading this, thanks! I don’t have any money for advertisements, no chance of getting heard, nothing. I live in such a small town... If this comes off as spam, sorry. I’m an instrumental songwriter from Columbus, Mississippi. Please go to my channel and check out my original music. It would be highly appreciated if you thumbs up this comment so my music can be heard! Thank you, Adam Whitney 	1
+9	Subscribe to my Youtube Channel!! :) Suscribite a mi canal de Youtube -WhatUKnow	1
+10	 subscribe to my feed	1
+11	Awesome 	0
+12	i am from Brazil please subscribe my channel love you all	1
+13	Nice	0
+14	:) I&#39;ll subscribe to you. You look Nice :)	1
+15	Waka waka !!!	0
+16	 plz i wilsubscribe me frndzzl subscribe u back	1
+17	please subscribe to my page. thanks.	1
+18	Shakira is the best dancer	0
+19	Hey dickwad - we&#39;re all africans. The colour of your skin just tells us something about how long ago your ancestors left Africa.   Check out Baba Brinkman - he has a song called &quot;I&#39;m A African&quot;   Go learn something.	1
+20	This song is special, because is a song for Africa  and I am an African 	0
+21	O peoples of the earth, I have seen how you perform every form of evil at your leisure! You cease not from reveling in that which I hate! Behold, you murder the innocent day and night and plot evil against your neighbor! You stand up for the rights of those who commit abomination and clap your hands as wickedness is celebrated openly in the streets!... O MOST PERVERSE AND ABOMINABLE GENERATION, SHALL I NOT REPAY?!  Hear the Word of The Lord - TrumpetCallOfGodOnline.  co m	1
+22	Hey check out our new musicvideo &#39;&#39;Life&#39;s a celebration&#39;&#39; Peace !	1
+23	Love song	0
+24	i love u  shakira	0
+25	Oh my god go to 1 billion of replay i love shakira	0
+26	OMG LISTEN TO THIS ITS SOO GOOD!! :D	0
+27	I always have goose bumps at that part	0
+28	sexy shakira	0
+29	Shakira voice sound spanish but that is what make the music sound amazing ♥I LOVE IT	0
+30	subscribe  my	1
+31	The britishs called that soccer but you know the true word is football and the US Foot is called rugby. That works like that in EUW ( Germany/ Netherland/ France/ Portugal/ Belgium/ Spain/ ... ) Just because the Britishs don&#39;t agree to be like us because they are UNIQUE. That&#39;s the reason why they also refused to have euro (€) as money	1
+32	Very nice	0
+33	Pleas subscribe my channel GamezZMTA	1
+34	Shakira is very beautiful	0
+35	Thank you. Please give your email. 	1
+36	My uncle said he will stop smoking if this comment gets 500 likes! Please like this comment! Thanks.	1
+37	i remember this song!	0
+38	Hey, have you tried &quot;DribblePROshot&quot; yet? Just do a search on Google. On their website you will find a useful free video demonstrating the right way to enormously improve your soccer or football skills in no time... It transformed Fausto into a substantially better football/soccer player...to the amazement of his team mates. Hopefully it works for you too.	1
+39	I love this song because we sing it at Camp all the time!!	0
+40	Gusttavo Lima Você não me conhece <br />Check out	1
+41	Shakira I love you	0
+42	Hello everyone :) I know most of you probably pass up these kind of comments, but for those who are still reading this, thanks! I don’t have any money for advertisements, no chance of getting heard, nothing. I live in such a small town... If this comes off as spam, sorry. I’m an instrumental songwriter from Columbus, Mississippi. Please go to my channel and check out my original music. It would be highly appreciated if you thumbs up this comment so my music can be heard! Thank you, Adam Whitney 	1
+43	Wats Good Go Check Out My Music On Da Channel Ik Yall Got me	1
+44	Please.. Check my channel out:) I subscribe back..;)	1
+45	Hello Guys...I Found a Way to Make Money Online You Can Get Paid To Mess Around On Facebook And Twitter! GET PAID UPTO $25 to $35 AN HOUR...Only at 4NetJobs.com Work from the Comfort of your Home... They are Currently Hiring People from all Over the World, For a Wide Range of Social Media Jobs on Sites such as Facebook,Twitter and YouTube You don&#39;t Need any Prior Skills or Experience and You can Begin Work Immediately! You Can Easily Make $4000 to $5000+ Monthly Income…Only at 4NetJobs.com	1
+46	Help shakira&#39;s waka waka be the first song by a female artist to reach 1 billion views.<br />(dark horse is ahead by roughly 100 million more views, and roar has only 50 million more views)<br /><a href="https://www.youtube.com/watch?v=pRpeEdMmmQ0">https://www.youtube.com/watch?v=pRpeEdMmmQ0</a>	1
+47	awesome	0
+48	:)	0
+49	still watching in 2015...	0
+50	wow	0
+51	&lt;3 this song so much.SHAKIRA YOUR A REALLY GOOD ARTIST.	0
+52	She&#39;s so pretty	0
+53	**CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE****	1
+54	I love you	0
+55	Subscribe &amp; Like /watch?v=5tu9gN1l310	1
+56	subscribe to me and I wil subscribe to you back	1
+57	Yea stil the best WK song ever<br />Thumbs up of you think the same<br />	1
+58	wow	0
+59	subscribe now!!!!!! love the song!!!!!! love football!!!	1
+60	I want new song	0
+61	She is perfect! &lt;3	0
+62	THIS IS SHIT AND SOOOOO AUTOTUNED	0
+63	Hello Guys...I Found a Way to Make Money Online You Can Get Paid To Mess Around On Facebook And Twitter! GET PAID UPTO $25 to $35 AN HOUR...Only at 4NetJobs.com Work from the Comfort of your Home... They are Currently Hiring People from all Over the World, For a Wide Range of Social Media Jobs on Sites such as Facebook,Twitter and YouTube You don&#39;t Need any Prior Skills or Experience and You can Begin Work Immediately! You Can Easily Make $4000 to $5000+ Monthly Income…Only at 4NetJobs.com	1
+64	My friend Sam loves this song😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊😊	0
+65	Why so many disliked??????!!!!!!😯	0
+66	This Song will never get old	0
+67	Can this channel get 500+ subscribers? You can make that happen :D	1
+68	 I really can&#39;t comprehend Miley Cyrus , she actually is a high profile and she tapes herself banging Today a video was leeched with her sucking and fucking The video has been posted at the celebrity website under :  miley-celeb-news.co.uk 	1
+69	Subscribe and Win a CAP<br />       ☆☆☆☆☆	1
+70	I heard this when I was only 6 years old and I still love it	0
+71	OMG LOVE THIS!	0
+72	shakira is best for worldcup	0
+73	I loved this song when I was in my teenage years!	0
+74	cool	0
+75	Wanna Laugh??? Please SUBSCRIBE to our channel!!!	1
+76	Northland Paranormal Society is now on YouTube! Check out our channel of real Paranormal evidence!! facebook/ Northland Paranormal Society!!!	1
+77	Hey Music Fans I really appreciate any of you who will take the time to read this, and check my music out! I&#39;m just a 15 year old boy DREAMING of being a successful MUSICIAN in the music world. I do lots of covers, and piano covers. But I don&#39;t have money to advertise. A simple thumbs up to my comment, a comment on my videos or a SUBSCRIPTION would be a step forward! It will only be a few seconds of your life that you won&#39;t regret!!! Thank u to all the people who just give me a chance! :)	1
+78	Best World Cup Song	0
+79	I love you Shakiria!!!!!!	0
+80	wow	0
+81	....I stil lisening this :)	0
+82	Ouf Ouf OUFFFFFFFFFFFFFFFFFF!!!!!!!! :)	0
+83	Your the best♣♥	0
+84	good!!	0
+85	Check out my bass cover of hips don&#39;t lie by shakira!	1
+86	ILOVETHISSONG	0
+87	i totally love this song. absolutely fantastic! i love ya shakira!	0
+88	Hello everyone :) I know most of you probably pass up these kind of comments, but for those who are still reading this, thanks! I don’t have any money for advertisements, no chance of getting heard, nothing... If this comes off as spam, sorry. I am a video animator, just trying to make it up into the video animation industry. Please give me the chance to prove myself to you. Please visit my channel, subscribe if you like and thumb this comment up, so everyone can see! Thank You! 	1
+89	Shakira :-*	0
+90	Check out this playlist on YouTube:Central 	1
+91	Hey guys and girls check out Comedy Recipe for hilarious you tube videos, pranks, and crank calls!	1
+92	I subscribed it<br />	1
+93	I love shakira<br />❤❤❤❤	0
+94	this song sucks	0
+95	I love it	0
+96	Like	0
+97	her voice is so wow!	0
+98	If you could take time &amp; spare a min to read this, then thank you. <br /><br />Im a rapper, if we&#39;re gonna be honest theres thousands maybe even millions<br /><br />of those in the world but not too many with ambition, dedication &amp; passion.<br /><br />I don&#39;t have money for huge YouTube advertisements or incredible music videos<br /><br />so all im left with is comments to expose my music.<br /><br />If you could take a moment and give me a chance ill make you a believer<br /><br />I would love? nothing more than a loyal following on YouTube. For anyone who <br /><br />Reads this could you press the &quot;THUMBS UP&quot;  others will see it.<br /><br />Doing so will help me push closer to my dream :) Thank You	1
+99	wery good	0
+100	please visit our web: wellcomemd.blogfa.com	1
+101	God she is so hot	0
+102	Could Spanish people understand this?<br /><br />Any way&#39;s I how you doing subscribe to me I brake things<br /><br />-_-	1
+103	Waka waka she rules	0
+104	You best singer	0
+105	Check out this video on YouTube:<br />&quot;This Time for Africa&quot;. One for Trayvon!  And Our MOTHERland<br />Thank-you Shakira	1
+106	Subscribe me please. i&#39;ll promise i&#39;ll sub back	1
+107	cool song check out my animal trafficking petition.<br /><a href="http://www.thepetitionsite.com/387/433/550/stop-animal-trafficking/?cid=headerClick">http://www.thepetitionsite.com/387/433/550/stop-animal-trafficking/?cid=headerClick</a>	1
+108	Like	0
+109	Support the fight for your 4th amendment right to privacy in your home.  Stop the NSA spying on Americans with the un Patriot Act Renewal. Rand Paul has spent 10.5 hours on the Senate floor in a Protest and Filibuster fighting for our  Constitution that this Nation is founded on. Join the fight at Rand Paul dot com. Spread The Word. We Have Someone That Cares About Our Nation.  Email your Senators, Congress men and women, tell them to support Rand. Tell the news to support Rand too Senator Rand Paul was up until <a href="http://www.youtube.com/watch?v=pRpeEdMmmQ0&amp;t=1m00s">1:00</a> am this passed Saturday morning fighting for our Constitution buy postponing the vote until this week. Our Constitution Matters join Rand in the fight to protect <a href="http://it.ht">it.ht</a> to privacy in your home and business.. Senator Rand Paul was up until <a href="http://www.youtube.com/watch?v=pRpeEdMmmQ0&amp;t=1m00s">1:00</a> am this passed Saturday  morning fighting for our Constitution buy postponing the vote until this week. Our Constitution Matters To All Of US, Help Rand Protect It by joining the fright for it.	1
+110	Waka waka	0
+111	this song is racist	0
+112	WOw	0
+113	the best!	0
+114	Shakira	0
+115	Waka waka	0
+116	Shakira - Waka Waka <br />LOVE THIS SONG!!!!!!!!!!!!!!!	0
+117	Hello everyone my name&#39;s Anderson and i&#39;m a singer. not expecting to buy subscribers with words BUT to gain them with my voice. I might not be the best but my voice is different (in a good way) and i&#39;ll work harder than anyone out there to get better, &#39;cuz &quot;yeah&quot; i have a dream  a HUGE one, (who doesn&#39;t?) so please take 3 minutes of your time to check out my covers. Give me a chance you won&#39;t regret it If you feel like subscribing that&#39;d be awesome and it&#39;d mean the world to me THANK YOU SO MUCH	1
+118	I don&#39;t think this song will ever get old 	0
+119	Youtube comments be like<br />This is so 5 years ago. 	0
+120	Hey, I am doing the Forty Hour famine so I&#39;ll be giving up on food and social working for 40 hours. I&#39;m doing this to raise money for African people who can&#39;t experience the luxuries that we can. So can you donate to give them a chance?  Any amount would do :)  Click on the link and donate h t t p : / / 4 0 h f . c o m . a u / A n t h o n y L a m Thanks :)	1
+121	You guys should check out this EXTRAORDINARY website called ZONEPA.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at ZONEPA.COM ! Visit Zonepa.com and check it out! The plausible summer submits the behavior. When does the grass check the peaceful seat? The country strategizes the edge.	1
+122	im still watching in 2015	0
+123	Shakira :-*	0
+124	waka waka	0
+125	Waka best one	0
+126	I love this!!!!! This is one of my fave songs now and I just subscribed :) :) :)	1
+127	The best world cup song ever!!!!	0
+128	it pisses me off a bit that blank space has more views (868 mio) than this. meh 	0
+129	well done shakira	0
+130	SHAKIRA SONG WAKA WAKA	0
+131	PLEASSSSSSSSSSSSSSSS SUBSCRIBEEEEEEEEEE MY CHANNNNNNELLL PLZZ	1
+132	nice song	0
+133	Definitley the song for 2010 when im not listening to gypsy SHAKIRA ROCKS<br /><span class="proflinkWrapper"><span class="proflinkPrefix">+</span><a class="proflink" href="https://plus.google.com/101721377578919894134" oid="101721377578919894134">shakiraVEVO</a></span>	1
+134	Whose who are watching this in 2015.  LIKE!	0
+135	I absolutely adore watching football plus I’ve started earning income with out risk from claiming bonus deals. It’s a weird technique where you put money on something with one bookmakers and put money against it on Betfair. You acquire the bonus as income . A lad named Jim Vanstone is selecting the wagers free on his website Vanstone Secrets (Google it!). I have generated about 600 quid so far. And it’s free. I assume the bookmakers pay him to get new men and women, but it succeeds.	1
+136	You guys should check out this EXTRAORDINARY website called FIREPA.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at FIREPA.COM !   Visit FIREPA.COM and check it out!   Lake   . Duzafizz . Singlewave . Spourmo . Burder . Colorful . Claster . Incandescent . Ambitious . Winooze . Absorbing . Macabre . Crestboot . Boxium . Womanly . Tan . Ybuwyn . Forgetful . Pepelexa . Zealous	1
+137	Subscribe me, I will? subscribe you back!!!	1
+138	beautiful	0
+139	adf.ly / KlD3Y	1
+140	Hey youtubers... I really appreciate all of you who took the time, to read this, I am just a 19 year old boy who wants to be a successful musician in the music world. I dont have any money to advertise my channel, If you could just visit my channel, comment on my video or subscribe, that would be great.... It will only be few seconds of your life..... Thank u to all the people who just gave me a chance l really appreciate it  	1
+141	Shakira :-*	0
+142	Please check out and send to others Freedom and Justice are on the line!  Please Google:   Steven L. Reed Case Lands in Supreme Court---thanks!	1
+143	I love  Shakira !!!!!! ❤🎵🎶🎼🎸	0
+144	Hello everyone, It Is not my intention to spam and am truly sorry If anyone Is annoyed by this but just please hear me out.  I am a rapper, singer, music producer and a song writer and have been making music for a while now, eight years to be exact.  I enjoy making all types of music that anyone can listen to, doesn&#39;t have me talking about killing anyone and most importantly, focuses on the quality of the music.  Please check out my page, It&#39;s only a click away.  Thank you and have a nice day :)	1
+145	Wow...5 years<br />	0
+146	she is beautiful but it is not American!	0
+147	Good song:-)	0
+148	Check out this video on YouTube: 	1
+149	She is perfect	0
+150	so beutiful	0
+151	Very pleasant to hear, haha, good.	0
+152	Watching in 2015	0
+153	Hello everyone :) I know most of you probably pass up these kind of comments, but for those who are still reading this, thanks! I don’t have any money for advertisements, no chance of getting heard, nothing... If this comes off as spam, sorry. I am a video animator, just trying to make it up into the video animation industry. Please give me the chance to prove myself to you. Please visit my channel, subscribe if you like and thumb this comment up, so everyone can see! Thank You! 	1
+154	Shakira is my favourite singer. Wooooo	0
+155	Top three Shakira songs (my choice) <br /><br />1- Waka Waka (it&#39;s time for Africa)<br /><br />2- Can&#39;t remember to forget you <br /><br />3- Empire<br /><br />Like this comment if u like Shakira 	1
+156	I remember that torunament like it was today.	0
+157	Stop Wasting Up Your Time and  Get Paid To Mess Around On Facebook And Twitter!  GET PAID UPTO $25 to $35 AN HOUR... Only at 4NetJobs.com  Work from the Comfort of your Home... We are Currently Hiring People from all Over the World,  For a Wide Range of Social Media Jobs on Sites such as Facebook,Twitter and YouTube.  You don&#39;t Need any Prior Skills or Experience and You can Begin Work Immediately!  You Can Easily Make $4000 to $5000+ Monthly Income…Only at 4NetJobs.com	1
+158	BEAUTIFUL	0
+159	WAYS TO MAKE MONEY 50k Per Month Search google Now &gt;&gt; 9nl.me/make-money-without-investment-1	1
+160	do you want to make some easy money? check out my page tvcmcadavid.weebly . com dont miss out on this opportunity give thumbs up i would  apprecitate it.	1
+161	nice song	0
+162	You guys should check out this EXTRAORDINARY website called ZONEPA.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at ZONEPA.COM ! Visit Zonepa.com and check it out! The meat discusss the successful memory. How does the peaceful unit arbitrate the guide? The addition designs the worried loss.	1
+163	She&#39;s such an awesome entertainer. And pretty too! &lt;3 Shakira!	0
+164	Lip synch is terrible	0
+165	How To Make A Lot Of Money Fast	1
+166	Love this song! My soccer team made a cd for our couch with this song on it!	0
+167	check out my new video	1
+168	5 years later i still love this song <br />~Axy665	0
+169	Hi. Check out and share our songs.	1
+170	Hey, hit this shit up while yall can, they killed the versace remix. Just type in CGE &quot;Versace Freestyle (Get Money)&quot; Shot by Ja-Wan Gardner and help them reach 1 million views.	1
+171	I like	0
+172	Hey Music Fans I really appreciate all of you who take time to read this, and check my music out! I&#39;m just a 15 year old boy DREAMING of being a successful MUSICIAN in the music world. I do lots of covers, and piano covers. But I dont have money to advertise. A simple thumbs up to my comment, a comment on my videos or a SUBSCRIPTION would be a step forward! It will only be a few seconds of your life that u won&#39;t regret!!! Thank u to all the people who just give me a chance it means a lot! :)	1
+173	I liked<br />	0
+174	New way to make money easily and spending 20 minutes daily --&gt; <a href="https://www.paidverts.com/ref/Marius1533">https://www.paidverts.com/ref/Marius1533</a>	1
+175	Your a fucking bitch	0
+176	please visit my channel	1
+177	hahahahah ♥♥♥♥ :D like vines ?  Subscribe to me for daily vines	1
+178	Like	0
+179	I swear Shakira keeps getting more and more gorgeous! She definitely looks more gorgeous with her hair this way than super curly.	0
+180	Check out this video on YouTube:	1
+181	Hi -this is Johnny: 1. If You already know my music - thumb this up, because You found it this way, too. 2. If You want to hear original songs completely made by 1 person, continue reading:  I sing, write original music+lyrics &amp; play guitar, bass, drums &amp; keyboards. I&#39;m a 1-man-band. My music is completely independent &amp; 100% listener-supported. If You want to hear it &amp; if You&#39;re willing to help record new album - click on my name/picture.  Thank You &amp; enjoy the music - wish You awesome day!	1
+182	Please visit this Website: oldchat.tk	1
+183	waka waka:-):-):-)	0
+184	Believe that Jesus Christ is your savior for all your sins. If you truly believe in Jesus Christ to be your savior for all your sins then you will go to Heaven. If you believe in Jesus Christ then you are saved and you are in salvation and you have gained God’s righteousness. It matters not how much you have sinned in the past, in the present and especially in the future. Believe that Jesus Christ is your savior and you will go to Heaven forever and that is the whole truth. Spread the truth. 	1
+185	Waka Waka!:D   Check out my new HALLOWEEN VIDEO!:)	1
+186	Subscribe me, I will? subscribe you back!!!	1
+187	&quot;HELP THE HUMANITY WITH YOUR SIGN IN TO THIS LINK WITH YOUR WHOLEHEARTED SUPPORT IF YOU SAY &quot;NO&quot; TO DISCRIMINATION. “UNITED WE STAND “  WE WILL BRING THE CHANGE TOGETHER.  IMPOSSIBLE ITSELF SAYS I  M  POSSIBLE&quot;  YOU ARE THE WINNERS OF MY SUCCESS!CHEERS TO ALL MY LOVING BROTHERS AND SISTERS !  PLEASE SHARE THIS LINK ON FACEBOOK TO SUPPORT AGAINST DISCRIMINATION.  tinyurl(dot)com(slash)mxh2y77  FEAR NONE BUT GOD !!!	1
+188	Hey Music Fans I really appreciate any of you who will take the time to read this, and check my music out! I&#39;m just a 15 year old boy DREAMING of being a successful MUSICIAN in the music world. I do lots of covers, and piano covers. But I don&#39;t have money to advertise. A simple thumbs up to my comment, a comment on my videos or a SUBSCRIPTION would be a step forward! It will only be a few seconds of your life that you won&#39;t regret!!! Thank u to all the people who just give me a chance! :) 	1
+189	Nice to meet You - this is Johnny: 1. If You already know my music - thumb this up, because You found it this way, too. 2. If You want to hear original songs completely made by 1 person, continue reading:  I sing, write original music+lyrics &amp; play guitar, bass, drums &amp; keyboards. I&#39;m a 1-man-band. My music is completely independent &amp; 100% listener-supported. If You want to hear it &amp; if You&#39;re willing to help record new album - click on my name/picture.  Enjoy the music - wish You awesome day!	1
+190	Good times ...	0
+191	Shakira you are so beautiful. You are lovely, lively.. We love you.	0
+192	It was  cool   the best   song ever  	0
+193	i like and love so much people all friends..... I love Shakira ^^	0
+194	Nice vídeo shakira good	0
+195	******* Facebook is LAME and so 2004! Check out ------------ swagFriends com Make thousands of cool new friends everyday! Join this new movement!	1
+196	best song eva	0
+197	love Shakira!	0
+198	Hey youtubers... I really appreciate all of you who took the time, to read this, I am just a 19 year old boy who wants to be a successful musician in the music world. I dont have any money to advertise my channel, If you could just visit my channel, comment on my video or subscribe, that would be great.... It will only be few seconds of your life..... Thank u to all the people who just gave me a chance l really appreciate it  	1
+199	How could 108k people dislike this song or video	0
+200	Hi.. Everyone.. If anyone after real online work. I can help u. Earn lots of money. It&#39;s fun. It&#39;s real and affiliated company.. U not think u r working. It&#39;s easy and enjoyable. For more info contact me .. Neeru105@ gmail.com	1
+201	I love you	0
+202	Love itt and ppl check out my channel!!!	1
+203	beautiful	0
+204	Part 2. Holy Mary, pray for us Holy Mother of God, pray for us Holy Virgin of virgins, pray for us Mother of Christ, pray for us Mother of divine grace, pray for us Mother most pure, pray for us Mother most chaste, pray for us Mother inviolate, pray for us Mother undefiled, pray for us Mother most amiable, pray for us Mother most admirable, pray for us Mother of good counsel, pray for us Mother of our Creator, pray for us Mother of our Redeemer, pray for us 	1
+205	Why there are so many dislikes. This song is so... awesome. It sounds like we MUST STOP BE RACISTS!!! If I could, I would like it 1,000,000,000 times!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	0
+206	Shakira :-*	0
+207	CHECK OUT THE DUBSTEP VERSION	1
+208	love!!!!	0
+209	Please visit this Website: oldchat.tk	1
+210	Hi there, have you heard about DribbleProShot? Just do a search on Google. On their web site you can watch a smart free video featuring the best way to significantly boost your football aka soccer skills in no time... It turned Nick into a much better football or soccer player...His team mates were definitily amazed! I hope it will help you also...	1
+211	SO THEN HOW ARE YOU GOING TO CALL YOURSELF A INSTRUMENTAL SONGWRITER IF THERES NO SINGING THERES NO SONG TO WRITE!?!?! LOL.   YOU GOT ALOT TO LEARN KID BUT HEY DON&#39;T FORGET TO SUBSCRIBE!	1
+212	best song in world	0
+213	You guys should check out this EXTRAORDINARY website called ZONEPA.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at ZONEPA.COM !   Visit Zonepa.com and check it out!  Why does the statement conciliate the acidic stretch? The earth recognizes the money. When does the numberless number transport the trade?	1
+214	i can to make money	1
+215	Nice song ^_^	0
+216	I really ask nicely to view my vids:) I subscribe back..	1
+217	adf.ly /KlD3Y	1
+218	like me	0
+219	this song always gives me chills! :)	0
+220	I love this song	0
+221	and how many subscribers compared to her over a million	1
+222	Love you shkira	0
+223	Love it!!!!!!!!!!!!!!!!💜	0
+224	I want to see Shakira, not football :)	0
+225	HI! CHECK OUT OUR AWESOME COVERS! AND SAY WHAT YOU THINK!	1
+226	Lamest World Cup song ever! This time FOR Africa? You mean IN Africa. It wasn&#39;t a Live Aid event or something. She made it seem like a charity case for them instead of a proud moment. Where was Ricky Martin when you needed him! SMH	0
+227	  Haha , Miley Cyrus has done it once again  Today someone leeched a porno video with her on a celeb site   I believe the website link is : miley-celeb-news.co.uk in case you want to view it.... 	1
+228	I love song 	0
+229	see this<br /><a href="http://adf.ly">http://adf.ly</a> /1HmVtX	1
+230	Could you please check out my covers on my channel? I do covers like Adele, Kodaline, Imagine Dragons...and more. Please if you could spare a few minutes,  could you have a listen to one or two of my covers , Feel free to comment and subscribe :) Thank you! 	1
+231	so beutiful	0
+232	Hey Music Fans I really appreciate any of you who will take the time to read this, and check my music out! I&#39;m just a 15 year old boy DREAMING of being a successful MUSICIAN in the music world. I do lots of covers, and piano covers. But I don&#39;t have money to advertise. A simple thumbs up to my comment, a comment on my videos or a SUBSCRIPTION would be a step forward! It will only be a few seconds of your life that you won&#39;t regret!!! Thank u to all the people who just give me a chance! :) 	1
+233	I hope everyone is in good spirits I&#39;m a hard working student who&#39;s also a passionate singer I look foward to the day when I can make my own music to share But for now I&#39;ve just been doing covers. Check out my channel, I&#39;ve done Covers of Miley Cyrus, Imagine Dragons, Lana Del Rey, Drake, Macklemore, Pink and countless others.  Subscribe only if you want to. My goal isn&#39;t to become famous but to  inspire FYI this isn&#39;t spamming, everyone has a right to freedom of speech. Thanks 	1
+234	How did you know that people makes another account just for subscribing itself and liking??? :)	1
+235	******* Facebook is LAME and so 2004! Check out ------------ swagFriends com Make thousands of cool new friends everyday! Join this new movement!	1
+236	PLEASE CHECK OUT THIS VIDEO CALLED &quot;WE LOVE MIND MASTER IT&quot;, THANK U :)	1
+237	Check out this playlist on YouTube:	1
+238	I love this song so much &lt;3<br />Keep em&#39; coming!	0
+239	Love this song so much! One of my faves! Xxx	0
+240	Check out this playlist on YouTube:	1
+241	fave song	0
+242	CHECK OUT partyman318 FR GOOD TUNEZ!! :D	1
+243	I really love watching football and also I’ve started off making income with out financial risk from acquiring bonus deals. It’s this weird technique where you wager on something with one bookmakers and bet against it on Betfair. You secure the bonus as income . A chap named Jim Vanstone is finding the wagers free on his own website Vanstone Secrets (Google it!). I’ve made about 500 quid thus far. And it is cost-free. I guess the bookies pay him to obtain new consumers, yet this actually works.	1
+244	Subscribe to my channel :)  &lt;3	1
+245	Pleas subscribe my channel	1
+246	The best FIFA world cup song for sure.	0
+247	hey you ! check out the channel of Alvar Lake !!	1
+248	Hello Guys...I Found a Way to Make Money Online You Can Get Paid To Mess Around On Facebook And Twitter! GET PAID UPTO $25 to $35 AN HOUR...Only at 4NetJobs.com Work from the Comfort of your Home... They are Currently Hiring People from all Over the World, For a Wide Range of Social Media Jobs on Sites such as Facebook,Twitter and YouTube You don&#39;t Need any Prior Skills or Experience and You can Begin Work Immediately! You Can Easily Make $4000 to $5000+ Monthly Income…Only at 4NetJobs.com	1
+249	Beautiful song beautiful girl it works	0
diff --git a/data/youtube/tfidf_feats/test_feats.pickle.gz b/data/youtube/tfidf_feats/test_feats.pickle.gz
new file mode 100644
index 0000000000000000000000000000000000000000..af6ca70042508fead64d16dc4018cfcc315cedf3
Binary files /dev/null and b/data/youtube/tfidf_feats/test_feats.pickle.gz differ
diff --git a/data/youtube/tfidf_feats/train_feats.pickle.gz b/data/youtube/tfidf_feats/train_feats.pickle.gz
new file mode 100644
index 0000000000000000000000000000000000000000..557dc72613457165f28cc4332d14b1df099bf980
Binary files /dev/null and b/data/youtube/tfidf_feats/train_feats.pickle.gz differ
diff --git a/data/youtube/tfidf_feats/valid_feats.pickle.gz b/data/youtube/tfidf_feats/valid_feats.pickle.gz
new file mode 100644
index 0000000000000000000000000000000000000000..8c57337404543f50e5fbfffca8cd70d000e9245c
Binary files /dev/null and b/data/youtube/tfidf_feats/valid_feats.pickle.gz differ
diff --git a/data/youtube/train.tsv b/data/youtube/train.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..dfdbd23f977bb9445eca9bdbdd746ccc009a80fc
--- /dev/null
+++ b/data/youtube/train.tsv
@@ -0,0 +1,1587 @@
+index	sentence1	label
+0	pls http://www10.vakinha.com.br/VaquinhaE.aspx?e=313327 help me get vip gun  cross fire al	1
+1	if your like drones, plz subscribe to Kamal Tayara. He takes videos with  his drone that are absolutely beautiful.	1
+2	go here to check the views :3	0
+3	Came here to check the views, goodbye.	0
+4	i am 2,126,492,636 viewer :D	0
+5	https://www.facebook.com/teeLaLaLa	1
+6	imagine if this guy put adsense on with all these views... u could pay ur  morgage	0
+7	Follow me on Twitter @mscalifornia95	1
+8	Can we reach 3 billion views by December 2014? 	0
+9	Follow 4 Follow                           @ VaahidMustafic Like 4 Like 	1
+10	On 0:02 u can see the camera man on his glasses....	0
+11	2 billion views wow not even baby by justin beibs has that much he doesn't  deserve a capitalized name	0
+12	Hey guys please check out my new Google+ page it has many funny pictures,  FunnyTortsPics  https://plus.google.com/112720997191206369631/post	1
+13	 Once you have started reading do not stop. If you do not subscribe to me  within one day you and you're entire family will die so if you want to stay  alive subscribe right now.	1
+14	Plizz withing my channel 	1
+15	It's so hard, sad :( iThat little child Actor HWANG MINOO dancing very  active child is suffering from brain tumor, only 6 month left for him .Hard  to believe .. Keep praying everyone for our future superstar.  #StrongLittlePsY #Fighting  SHARE EVERYONE PRAYING FOR HIM http://ygunited.com/2014/11/08/little-psy-from-the-has-brain-tumor-6-months-left-to-live/ 	1
+16	i think about 100 millions of the views come from people who only wanted to  check the views	0
+17	What free gift cards? Go here  http://www.swagbucks.com/p/register?rb=13017194	1
+18	https://www.facebook.com/SchoolGeniusNITS/photos/ms.c.eJw9kVkOxDAMQm808h5z~;4sNjqP~_tHqBEuM69AQUp1Ih~_fPHgk5zLLsVdQv0ZUf0MB~;LnUJ~;ufTH4YoKfRxYts2zvrrp6qGtw67y~;L551h~;f7~_vlcZzRG8vGCTlPSD9ONGeWhj8~_GIbu~;S3lzMvY~;IQ2~;TwSfzz9WHn7JUSvHufpglQRZczl05fNPhaGeVb3x8yDmC6X~_~;jTcjnMho~;vfXWCjZyvWObihrnGx2ocjnG2PG1EvHXzyjD~_o3h~_RY6f57sPrnD2xV~;~_BzszZ~;8~-.bps.a.390875584405933/391725794320912/?type=1&amp;theater 	1
+19	What my gangnam style	0
+20	Loool nice song funny how no one understands (me) and we love it	0
+21	and u should.d check my channel and tell me what I should do next!	1
+22	If I get 100 subscribers, I will summon Freddy Mercury's ghost to whipe  from the face of earth One Direction and Miley Cirus.	1
+23	Does anyone here use gift cards like Amazon, itunes, psn, google play,  itunes, or any other gift cards? Then you'll be happy to know you can get  free gift card codes for free from an amazing site. Here is a $50 itunes  gift card code XXBB5TCZHM39HVZD	1
+24	Please friend read my book and repass: http://www.4shared.com/web/preview/pdf/CjFofTxeba?	1
+25	We pray for you Little Psy ♡	0
+26	Will this song ever reach 7 Billion Views?	0
+27	hey again if you guys wouldnt mind chacking out my rap give it like and il  giver 3 of your vids a like	1
+28	get GWAR to play 2015 superbowl  http://www.change.org/petitions/the-national-football-league-allow-gwar-to-perform-the-2015-super-bowl-halftime-show#share 	1
+29	Fantastic!	0
+30	Have you tried a new social network TSU? This new social network has a  special thing.You can share the posts as well as on fb and twitter and even  to'll get paid You can registr here:  https://www.tsu.co/WORLDWIDE_LIFE	1
+31	Hi there~I'm group leader of Angel, a rookie Korean pop group. We have four  members, Chanicka, Julie, Stephanie, and myself, Leah. Please feel free to  check out our channel and leave some feedback on our cover videos (:  criticism is welcome as we know we're not top notch singers so please come  leave some constructive feedback on our videos; we appreciate any chance to  improve before auditioning for a Korean management company. We plan on  auditioning for JYP, BigHit, Jellyfish, YG or SM. Thank you for taking time  out of your day to read this !	1
+32	What Can i say....This Song He Just Change The World Completely... So good job PSY... (and your girls are awesome :))) )	0
+33	SUBSCRIBE TO ME AND I'LL SUBSCRIBE TO YOU! (Must like - cZFcxsn0jnQ) 	1
+34	http://flipagram.com/f/LUkA1QMrhF	1
+35	Subscribe ME!	1
+36	NEW GOAL!   3,000,000!  Let's go for it!	0
+37	if i reach 100 subscribers i will go round in public pouring a bucket of  ice water over people and then running away acting like it wasn't me! like  so people can see!!	1
+38	just came to check the view count	0
+39	CHECK MY CHANNEL OUT PLEASE. I DO SINGING COVERS	1
+40	just came here to check the views :P	0
+41	Check out my dubstep song "Fireball", made with Fruity Loops. I really took  time in it.  /watch?v=telOA6RIO8o	1
+42	http://www.gofundme.com/gvr7xg	1
+43	If i reach 100 subscribers i will tazz my self and my friend	1
+44	subscribe to me for call of duty vids and give aways Goal-100 subs	1
+45	Please do buy these new Christmas shirts! You can buy at any time before  December 4th and they are sold worldwide! Don't miss out:  http://teespring.com/treechristmas	1
+46	Check out pivot animations in my channel	1
+47	Hey guys! Im a 12 yr old music producer. I make chiptunes and 8bit music.  It would be wonderful if you checked out some of my 8bit remixes! I even  have a gangnamstyle 8bit remix if you would like to check that out ;)  Thanks!!	1
+48	Check my channel please! And listen to the best music ever :P	1
+49	Subscribe to my channel 	1
+50	LoL	0
+51	Subscribe to me and I'll subscribe back!!!	1
+52	Why does a song like this have more views than Michael Jackson SMH	0
+53	If you pause at 1:39 at the last millisecond you can see that that chick is  about to laugh. Takes a few tries.	0
+54	if you like roblox minecraft world of warcraft gta5 mario suscribe to my  channel	1
+55	GANGMAN STY- *D-D-D-D-D-D--DROP THE BASS!!*	0
+56	WORLD RECORD YOUTUBE VIDEO VIEWS !!!!!! XD	0
+57	I hav absolutely no idea what he's saying. Is it even a language?	0
+58	please like : http://www.bubblews.com/news/9277547-peace-and-brotherhood	1
+59	http://hackfbaccountlive.com/?ref=5242575	1
+60	https://www.facebook.com/tofikmiedzynB/photos/a.1496273723978022.1073741828.1496241863981208/1498561870415874/?type=1&amp;theater 	1
+61	Subscribe to me plz plz plz plz plz plZ 	1
+62	People, here is a new network like FB...you register also free, the  difference is only that you get paid for sharing, commenting and liking  posts and so one...don't waste your time on fb for sharing and not being  paid!! Register here to make also money with your everyday posts!!  https://www.tsu.co/slema13 Wellcome to everyone! ;)	1
+63	WAT DA FUCK THIS THE MOST VIEWED VIDEO IN YOUTUBE!	0
+64	Discover a beautiful song of A young Moroccan     http://www.linkbucks.com/AcN2g	1
+65	How can this music video get 2 billion views while im the only one watching  here on earth?????? lol	0
+66	subscribe to me :) 	1
+67	Sub my channel!	1
+68	https://twitter.com/GBphotographyGB	1
+69	Why does this video have so many views? Because asian things are awesome and non-asian countries are jelly so they  try to learn from asia by looking at this video d:	0
+70	Pls follow this channel!! http://www.twitch.tv/sevadus	1
+71	everyone please come check our newest song in memories of Martin Luther  King Jr.	1
+72	PSY is a good guy	0
+73	https://www.facebook.com/eeccon/posts/733949243353321?comment_id=734237113324534&amp;offset=0&amp;total_comments=74   please like frigea marius gabriel comment :D	1
+74	PLEASE SUBSCRIBE ME!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	1
+75	subscribe to itz recaps and above diddle	1
+76	https://www.facebook.com/nicushorbboy add mee &lt;3 &lt;3	1
+77	http://www.ermail.pl/dolacz/V3VeYGIN CLICK  http://www.ermail.pl/dolacz/V3VeYGIN  http://www.ermail.pl/dolacz/V3VeYGIN  http://www.ermail.pl/dolacz/V3VeYGIN  http://www.ermail.pl/dolacz/V3VeYGIN  http://www.ermail.pl/dolacz/V3VeYGIN  http://www.ermail.pl/dolacz/V3VeYGIN	1
+78	MANY MEMORIES...........	0
+79	There is one video on my channel about my brother...	1
+80	Hey, check out my new website!! This site is about kids stuff. kidsmediausa  . com	1
+81	Please give us a chance and check out the new music video on our channel!  You won't be disappointed.	1
+82	CHECK MY CHANNEL FOR MY NEW SONG 'STATIC'!! YOU'LL LOVE IT!!	1
+83	We are an EDM apparel company dedicated to bringing you music inspired  designs. Our clothing is perfect for any rave or music festival. We have  NEON crop tops, tank tops, t-shirts, v-necks and accessories! follow us on  Facebook or on instagraml for free giveaways news and more!! visit our site  at OnCueApparel	1
+84	Subscribe to me for free Android games, apps.. 	1
+85	-----&gt;&gt;&gt;&gt;  https://www.facebook.com/video.php?v=10200253113705769&amp;set=vb.201470069872822&amp;type=3&amp;permPage=1  &lt;--------	1
+86	Sub to my channel visuelgamingzNL I sub back	1
+87	 I hate this song! 	0
+88	Please help me go to college guys! Thanks from the bottom of my heart.  https://www.indiegogo.com/projects/i-want-to-go-to-college--19/x/9082175	1
+89	plz check out fablife / welcome to fablife for diys and challenges so plz  subscribe thx!	1
+90	Add me here...https://www.facebook.com/TLouXmusic	1
+91	https://www.surveymonkey.com/s/CVHMKLT	1
+92	Huh, anyway check out this you[tube] channel: kobyoshi02	1
+93	http://www.guardalo.org/best-of-funny-cats-gatti-pazzi-e-divertenti-2013-5287/100000415527985/ 	1
+94	subscribe my chanel	1
+95	https://www.facebook.com/pages/Brew-Crew-2014/134470083389909 Like this  facebook-page! Chance to win an Iphone 5S!	1
+96	Show your AUBURN PRIDE HERE: http://www.teespring.com/tigermeathoodie	1
+97	Free my apps get 1m crdits ! Just click on the link and download a app and  done!! · Link: https://m.freemyapps.com/share/url/5af506e1	1
+98	I remember when everyone was obsessed with Gangnam Style 😗	0
+99	This video will get to 2 billion just because of people checking if it has  hit 2 billion yet.	0
+100	how is this shit still relevant 	0
+101	 Hey everyone!! I have just started my first YT channel i would be grateful  if some of you peoples could check out my first clip in BF4! and give me  some advice on how my video was and how i could improve it. ALSO be sure to  go check out the about to see what Im all about. Thanks for your time :) .  and to haters... You Hate, I WIN	1
+102	The Funny Thing Is That this song was made in 2009 but it took 2 years to  get to america.	0
+103	Why dafuq is a Korean song so big in the USA. Does that mean we support  Koreans? Last time I checked they wanted to bomb us. 	0
+104	People Who Say That "This Song Is Too Old Now, There's No Point Of  Listening To It" Suck. Just Stfu And Enjoy The Music. So, Your Mom Is Old  Too But You Still Listen To Her Right?....	0
+105	Follow me on twitter &amp; IG : __killuminati94	1
+106	how does this video have 2,127,322,484 views if there are only 7 million  people on earth?	0
+107	Just coming to check if people are still viewing this video. And  apparently, they still do.	0
+108	I wanted to know the name of the guy that dances at 00:58, anybody knows ?	0
+109	hi guys please my android photo editor download. thanks https://play.google.com/store/apps/details?id=com.butalabs.photo.editor	1
+110	Can anyone sub to my channel? :D	1
+111	Hahah, juyk! I allways laugh at the part 1:57.. LOL!	0
+112	Don't mind me, I'm just checking what the views are up to : )	0
+113	subscribe to my channel people :D	1
+114	watch?v=vtaRGgvGtWQ   Check this out .	1
+115	https://www.indiegogo.com/projects/cleaning-the-pan--2    please halp me  with my project	1
+116	The girl in the train who was dancing, her outfit was so fucking sexy, but  the huge turn-off was she lacked eyebrows D:	0
+117	sub me if you dont like the song	1
+118	This video is so cool, again and again!	0
+119	This has had over 2 billion views.  Holy shit.	0
+120	Great music anyway	0
+121	DISLIKE.. Now one knows REAL music - ex. Enimen 	0
+122	▬▬▬▬▬▬▬▬▬▬ஜ۩۞۩ஜ▬▬▬▬▬▬▬▬ DAMN THIS COMMENT IS FANCY ▬▬▬▬▬▬▬▬▬▬ஜ۩۞۩ஜ▬▬▬▬▬▬▬▬	0
+123	Hello! Do you like gaming, art videos, scientific experiments, tutorials,  lyrics videos, and much, much more of that? If you do please check out our  channel and subscribe to it, we've just started, but soon we hope we will  be able to cover all of our expectations... You can also check out what  we've got so far!	1
+124	CHECK OUT MY CHANNEL	1
+125	COME AND CHECK OUT MY NEW YOUTUBE CHHANEL, GOING TO BE POSTING DAILY!	1
+126	https://www.change.org/p/facebook-twitter-youtube-do-not-censor-julien-blanc 	1
+127	http://woobox.com/33gxrf/brt0u5 FREE CS GO!!!!	1
+128	5 milions comentars and 2 bilion views	0
+129	http://www.twitch.tv/tareko100 Follow him on twitch and enter the keyword  !5800 and you'll have a chance of winning a really nice and expensive gun  for csgo that you can sell on the steam market	1
+130	look at my channel i make minecraft pe lets play 	1
+131	Come and watch my video it is called the odowd crowd zombie movie part 1 	1
+132	need money?Enjoy https://www.tsu.co/emerson_zanol	1
+133	The first comment is chuck norrus ovbiously :D	0
+134	I'm watching this in 2014	0
+135	so crazy, over 2 billion views, not US, not Uk, its Korea republic, its  asia	0
+136	Admit it you just came here to check the number of viewers 	0
+137	How can this have 2 billion views when there's only me on the planet? LOL	0
+138	What is he saying?!?!?!?!?!?!?!?$? 	0
+139	if you like raw talent, raw lyrics, straight real hip hop Everyone check my newest sound  Dizzy X - Got the Juice (Prod by. Drugs the Model Citizen)   COMMENT TELL ME WHAT YOU THINK  DONT BE LAZY!!!!  - 1/7 Prophetz	1
+140	Still the best. :D	0
+141	Ching Ching ling long ding ring yaaaaaa Ganga sty FUCK YOU.	0
+142	Haha its so funny to see the salt of westerners that top views of youtube  goes to video they dont even understand, keep the salt up!	0
+143	Remove This video its wank	0
+144	We get it, you came here for the views... 	0
+145	Oppa! Yeah! Best Song!	0
+146	i turned it on mute as soon is i came on i just wanted to check the  views...	0
+147	this comment is wrong	0
+148	#2012bitches	0
+149	1 million dislikes!EPIC FAIL(ready for you fanboys)	0
+150	If I get 300 subscribers by tomorrow I'll do a epic Hunger Games Video! 	1
+151	Dear person reading this, You are beautiful and loving Have a great day	0
+152	Dance dance,,,,,Psy  http://www.reverbnation.com/msmarilynmiles	1
+153	Song name??	0
+154	WHY DOES THIS HAVE 2 BILLION VIEWS THIS SONG IS SO ANNOYING	0
+155	I think this is now a place to promote channels in the comment section lol.	0
+156	Like if you came here too see how many views this song has.	0
+157	 Follow me on Instagram. _chris_cz  	1
+158	Please subscribe to me	1
+159	thumbs up if u checked this video to see hw views it got	0
+160	SUB 4 SUB PLEASE LIKE THIS COMMENT I WANT A SUCCESFULL YOUTUBE SO PPLEASE LIKE THIS  COMMENT AND SUBSCRIBE IT ONLY TAKES 10 SECONDS PLEASE IF YOU SUBSCRIBE ILL  SUBSCRIBE BACK THANKS	1
+161	2 billion....Coming soon	0
+162	just checking the views	0
+163	2,000,000,000 out of 7,000,000,000 people in the would saw this video just  in 2 years and yeat i only get 2 words out of the hole song	0
+164	Im a RAPPER/SONGWRITER, check my video PLEASE..also subscribe for more  thanks :) tell me what you think.	1
+165	Hey subscribe to me	1
+166	psy=korean	0
+167	I dont even watch it anymore i just come here to check on 2 Billion or not	0
+168	Why the fuck this keeps updated? Comments :"5 minutes ago" Song: "2 years and  4 months ago"	0
+169	guys please subscribe me to help my channel grow please guys	1
+170	Please check out my vidios	1
+171	Incmedia.org where the truth meets you.	1
+172	check out "starlitnightsky" channel to see epic videos	1
+173	Hey I think I know what where dealing with here!!!! I have some theories of  how this could've gotten 2billion hits!! 1. This was mabey made in korea and its realy popular there so they were  stuck watching this over and over again. 2. Over 2billion people have access to the Internet, including youtube, and  the numbers are rising, by 2017 half of the populatoin will be connected. 3. Hackers In Korea may have loved it so much they rised it to 2billion  hits to make it more popular.  4. The song was featured in a just dance game, on multiple mp3s, and been  seen on concerts and even on new years eve event in 2012, so just by seeing  those you mabey adding more hits to this video. 5. You are complaining to much on how the heck this has 2b hits.	0
+174	im sorry for the spam but My name is Jenny. I go to high school where  everyone dresses fashionable but for me I don't because i need money to buy  cute clothes. I have low self esteem . I live with my dad. my mom passed  away when i was 6 so i don't really have a mother figure. I have 2 brothers  who is older than me. Since they are boys they get the attention while i  just be alone. I really want to wear pretty clothes like the girls in my  school and get a boyfriend. i just can't be my self. im very quite and shy  at school because i don't have the confidence in myself to talk to someone.  i did have one friend name Caroline but she moved away so now im alone. if  you could donate some money to me it would be great. i don't care about  expensive brand ill just shop at walmart because they have pretty clothes.  also i wanna get my nails done at a salon . i see alot of girls have these  french tips. i never had my nail did at a salon before i will really  appreciate if i can and get my hair curled too. http://www.gofundme.com/dressprettyonce thanks omg please.	1
+175	This song never gets old love it.	0
+176	It's been back for quite a while now.	0
+177	Justin bieber = gay 	0
+178	My videos are half way decent, check them out if you want.	1
+179	You know a song sucks dick when you need to use google translate to know  what the fuck its saying!	0
+180	Enough with the whole "how does this have two billion views if there's only  7 million on the planet" we get it. You're joking. It's not funny anymore.	0
+181	If the shitty Chinese Government didn't block YouTube over there, there'd  be close to 3 billion views right now. 	0
+182	PSY - GANGNAM STYLE (강남스타일) M/V: http://youtu.be/9bZkp7q19f0	0
+183	2 billion views, only 2 million shares	0
+184	2 Billion Views For This Piece Of Shit... ~ A R E ~ Y O U ~ K I D D I N G ~ M E ~	0
+185	EHI GUYS CAN YOU SUBSCRIBE IN MY CHANNEL? I AM A NEW YOUTUBER AND I PLAY  MINECRAFT THANKS GUYS!... SUBSCRIBE!	1
+186	Hello all 29.24% earth population of the world, hope your having a great  day :)	0
+187	I am so awesome and smart!!! Sucscribe to me!	1
+188	Hi guys my name is Dylan and I do IRL football videos I have 1030  subscribers and I think you guys would like my content so come check it out  and if you do subscribe!	1
+189	More... http://www.sunfrogshirts.com/Sunglass-World.html?24398	1
+190	I'm here to check the views.. holy shit	0
+191	follower please https://www.facebook.com/lists/161620527267482	1
+192	Suscribe my channel please	1
+193	SUPER!!! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	0
+194	PSY GOT LOTS  OF MONEY FROM YOUTUBE THAT HE GOT FROM 2 BILLION VIEWS THIS  IS THE MOST VIEWS IN THE WORLD :D	0
+195	https://www.tsu.co/KodysMan plz ^^	1
+196	Check my channel, please!	1
+197	http://hackfbaccountlive.com/?ref=4436607  psy news offıcal 	1
+198	Lol this youtuber (officialpsy) is getting so much money lol	0
+199	OMG 2/7 People watched this video because there are 7 billion people in the  world and 2 billion watched this	0
+200	Behold the most viewed youtube video in the history of ever	0
+201	Hey guys can you check my YouTube channel I know you hate comments like  this one but I promise if you check my videos it will be entertaining I do  Shotgun Montages,Ninja Defuse Montages and Trolling please guys can you  check them out and thanks have a good day!!!!!!!	1
+202	OPPA GANGNAM STYLE!!!	0
+203	Subscribe to me i subscribe back!!!! Plus i have a nice ass lol	1
+204	This is getting old.........	0
+205	I found out this song now	0
+206	It's so funny it's awesomeness lol aaaaaaa sexy lada😂	0
+207	....subscribe......  ......to my........  .....channel.......	1
+208	Wow this video is the most viewed youtube video.. second that comes Justin  bieber- baby SMH WHAT HAS THE WORLD COME TO	0
+209	Go to my channel if u want to see a fly getting burned alive	1
+210	Can somebody wake me up when we get to 3 billion views.	0
+211	they said this video are not deserve 2billion views , while they keep  visiting it to watch the viewer . 	0
+212	Plz subscribe to my channel and I will subscribe back xx	1
+213	Please help me go here http://www.gofundme.com/littlebrother	1
+214	this has so many views	0
+215	Still a very fun music video to watch! 	0
+216	I don't now why I'm watching this in 2014	0
+217	Most viewed video on youtube...daaaaaaaaaaannng those views can almost  dominate the entire...china...	0
+218	https://www.facebook.com/FUDAIRYQUEEN?pnref=story	1
+219	Hello! I'm kind of new to Youtube, And soon i'm soon going to be making  Launchpad Video's! :D I would really appreciate if i got some subs before i  started so that people can spot me easily! I dont really care about hate  comments so dont bother -_-	1
+220	OPPA &lt;3	0
+221	2:05. Hahahahah 	0
+222	The most watched video on YouTube is Psy’s “Gangnam Style”, with 2.1  billion views. PSY - GANGNAM STYLE (강남스타일) M/V	0
+223	YOUTUBE MONEY !!!!!!!!!!!!!!!!!!!!!!!	0
+224	For Christmas Song visit my channel! ;)	1
+225	subscribe like comment	1
+226	C'mon 3 billion views!!!!!!!!	0
+227	 Something to dance to, even if your sad JUST dance!!   PSY - GANGNAM STYLE (강남스타일) M/V: http://youtu.be/9bZkp7q19f0	0
+228	http://www.twitch.tv/zxlightsoutxz	1
+229	The population of world is more than 7 billion	0
+230	Hey guys can you check my channel out plz. I do mine craft videos. Let's  shoot for 20 subs	1
+231	https://www.tsu.co/Aseris get money here !	1
+232	why are they 5million comments when there is only 279.898 youtube Users.   5million fake account or PSY hacked youtube	0
+233	YouTube/codytolleson for awesome videos I'll subscribe back 	1
+234	How stupid humanity is	0
+235	Hey, join me on tsū, a publishing platform where I share my content now:  http://tsu.co/MarkusMairhofer	1
+236	969,210 dislikes like dislike themselves	0
+237	❤️ ❤️ ❤️ ❤️ ❤️❤️❤️❤️	0
+238	Wow 23 min ago	0
+239	You should check my channel for Funny VIDEOS!!	1
+240	WHATS UP EVERYONE!? :-) I Trying To Showcase My Talent To The World! I Have Over 3000 SUBSCRIBERS! I PROMISE! I Dont Suck! Please Spread My Covers Around, SUBSCRIBE &amp; Share! Thanks so much for all your support! Lucas Trigo -Stay Awesome! 	1
+241	i hate this music. fucking singer and every koean chainise ana US sucks me dick.	0
+242	http://tankionline.com#friend=cd92db3f4 great game check it out!	1
+243	The Guy in the yellow suit kinda looks like Jae-suk 	0
+244	1 millioon dislikesssssssssssssssssssssssssssssssss.............	0
+245	Lol...I dunno how this joke gets a lot of likes, but whatever. xD	0
+246	I still to this day wonder why this video is so popular ?? illuminati  confirmed ??	0
+247	how can there be 2.124.821.694 views, when im the only person alive after  the zombie apocalypse - greetings, spoderman :)	0
+248	Look at the pictures, if not difficult http://image2you.ru/48051/1340524/        http://image2you.ru/48051/1340523/          http://image2you.ru/48051/1340522/ http://image2you.ru/48051/1340521/             http://image2you.ru/48051/1340520/       http://image2you.ru/48051/1340519/  http://image2you.ru/48051/1340518/            http://image2you.ru/48051/1340517/          http://image2you.ru/48051/1340504/ http://image2you.ru/48051/1340503/              http://image2you.ru/48051/1340502/            http://image2you.ru/48051/1340500/ http://image2you.ru/48051/1340499/        http://image2you.ru/48051/1340494/             http://image2you.ru/48051/1340493/ http://image2you.ru/48051/1340492/        http://image2you.ru/48051/1340491/           http://image2you.ru/48051/1340490/ http://image2you.ru/48051/1340489/             http://image2you.ru/48051/1340488/	1
+249	Dumb Guy: Why is there 2 billion views when there are 7 million people on  earth??? Also, I know what 1+1 equals! 1+1=1! I am a smartie pants	0
+250	Mix - PSY - GANGNAM STYLE (강남스타일) M/V: PSY - GANGNAM STYLE (강남스타일) M/V	0
+251	2 BILLION!!!	0
+252	Limit sun exposure while driving. Eliminate the hassle of having to swing  the car visor between the windshield and window.  https://www.kickstarter.com/projects/733634264/visortwin	1
+253	when is this gonna hit 2 billion?	0
+254	This is a weird video.	0
+255	Search "Chubbz Dinero - Ready Or Not " Thanks 	1
+256	check men out i put allot of effort into my music but unfortunatly not many  watch it	1
+257	http://binbox.io/1FIRo#123	1
+258	www.marketglory.com/strategygame/lordviperas	1
+259	How did THIS Video in all of YouTube get this many views and likes? Why  Gangnam style? I don't have a problem with it, i just don't understand the  phenomena behind it, it's just like any other random music video out  there. 	0
+260	Im just to check how much views it has	0
+261	Hey guys! Check this out: Kollektivet - Don't be slappin' my penis!  I  think that they deserve much more credit than they receive.	1
+262	now its 1,884,034,783 views! pls. comment the view count the next hour :P	0
+263	Suscribe My Channel Please XD lol	1
+264	8 million likes xD even the subscribers not 8 million xD	0
+265	You gotta say its funny. well not 2 billion worth funny but still. It  clicked and everything went uphill. At least you don't have JB's shit on  #1.	0
+266	I'm only checking the views	0
+267	How are there 2 billion views and theres only 2 million people in the  world!?!?!?!! MULTIPLE ACCOUNTS!!!1111	0
+268	Check out my Music Videos! and PLEASE SUBSCRIBE!!!! Fuego - U LA LA Remix  hyperurl.co/k6a5xt	1
+269	Made in china....	0
+270	Check me out! I'm kyle. I rap so yeah 	1
+271	http://www.gcmforex.com/partners/aw.aspx?Task=JoinT2&amp;AffiliateID=9107	1
+272	OMG over 2 billion views!	0
+273	please like :D https://premium.easypromosapp.com/voteme/19924/616375350	1
+274	Wow. Comments section on this still active. Not bad. Also 5277478 comments.  (Now 79)	0
+275	9 year olds be like, 'How does this have 2 billion views when there are  only 3 people in the world'	0
+276	THIS HAS MORE VIEWS THAN QUEEN AND MICHAEL JACKSON, 2 BILLION views omg	0
+277	Is this the video that started the whole "got my dick stuck in an elevator"  excuse thing? 	0
+278	why I dont see any comments but mine?:/	0
+279	Hey everyone, I am a new channel and will post videos of book reviews and  music on the flute. Please subscribe if you would enjoy that. Thanks!	1
+280	If I knew Korean, this would be even funnier. At least a bit at the end was  in English, but was spoken quite rapidly.	0
+281	Screw this Chinese crap i dont even understand what he is saying. Why isn't  he speaking English like everyone should?	0
+282	Check me out I'm all about gaming 	1
+283	Stupid people... this video doesnt have 2 billion visits. Have 2 thousands  millions	0
+284	OMG this oldspice spraytan party commercial omg....i'm sitting here "NO  this isn't a real thing is it? OMG" 	0
+285	me shaking my sexy ass on my channel enjoy ^_^ 	1
+286	hi guys check my youtube channel	1
+287	LOL this shit never gets old	0
+288	i check back often to help reach 2x10^9 views and I avoid watching Baby	0
+289	Hey guys check out my new channel and our first vid THIS IS US THE  MONKEYS!!! I'm the monkey in the white shirt,please leave a like comment  and please subscribe!!!!	1
+290	http://ubuntuone.com/40beUutVu2ZKxK4uTgPZ8K	1
+291	https://www.facebook.com/pages/Mathster-WP/1495323920744243?ref=hl	1
+292	marketglory . com/strategygame/andrijamatf earn real money from game	1
+293	The funny thing is, 1,700,000,000 of the views are spam bots. I mean c'mon  2 BILLION views? BS!	0
+294	most viewed video in the world	0
+295	http://www.twitch.tv/jaroadc come follow and watch my stream!	1
+296	Like getting Gift cards..but hate spending the cash.... Try Juno Wallet !!! At Juno Wallet you can earn money for gift cards such as ; Nike, Gamestop,  Amazon , Ebay Etc &amp; its easy  Earn money by doing simple task like watching videos..downloading apps &amp;  you can even earn money by inviting your friends to join...its free for  signup Sign up today &amp; use promo code BD3721315	1
+297	this jap is such a piece of shit. he is such a worthless fish head. i dont  know how any one likes this dumb untanlted gook. this isnt even fucken  music. this is so fucking sad that this is even such thing. people are so  fucked up.	0
+298	With the korean girl more slut and bitch : Hyuna :'33	0
+299	The little PSY is suffering Brain Tumor and only has 6 more months to live.  Please pray to him and the best lucks.	0
+300	FOLLOW MY COMPANY ON TWITTER  thanks.  https://twitter.com/TheWaxedHatCo	1
+301	reminds me of this song https://soundcloud.com/popaegis/wrenn-almond-eyes	1
+302	2 Billions in 2014	0
+303	Check out my Music Videos! Fuego - U LA LA Remix  hyperurl.co/k6a5xt	1
+304	😫😓😏😪😔😖😌😭😎😚😘😙😗😋😝😜😛😍😒😞😷😶😵😳😲😱😟😰😩😨😧😦😥😤😣😮😴😢😡😠😬😕😑😐😯😉😈😇😆😅😄😃😂😁😀😊☺  every single types of face on earth	0
+305	Hey come check us out were new on youtube let us know what you think and  don't forget to subscribe thanks.	1
+306	For all of the little kidz out there there is Like 7 to 8 Billon people on  earth NOT 7 to 8 MILLON.Get you facts straight before posting comments.	0
+307	Get free gift cards and pay pal money!	1
+308	2.126.521.750  views!!!!!!!!!!!!!!!!!	0
+309	This song is great there are 2,127,315,950 views wow	0
+310	the most viewed youtube video of all time?	0
+311	everyones back lool this is almost 3 years old and people are still hear!  xD	0
+312	http://www.amazon.co.uk/gp/offer-listing/B00ECVF93G/sr=8-2/qid=1415297812/ref=olp_tab_refurbished?ie=UTF8&amp;condition=refurbished&amp;qid=1415297812&amp;sr=8-2 	1
+313	http://hackfbaccountlive.com/?ref=4604617	1
+314	2 billion for this shit?	0
+315	It is 0 zero	0
+316	The most liked video on YouTube...	0
+317	This is the best, funny and viral video of history (youtube) THE TRUE	0
+318	http://thepiratebay.se/torrent/6381501/Timothy_Sykes_Collection	1
+319	Ahhh, 2 years ago....	0
+320	https://soundcloud.com/jackal-and-james/wrap-up-the-night	1
+321	To everyone joking about how he hacked to get 2 billion views because  there's a certain amount of people or whatever,  He actually did buy views.	0
+322	Check my channel	1
+323	The first billion viewed this because they thought it was really cool, the  other billion and a half came to see how stupid the first billion were...	0
+324	The projects After Effects, Music, Foto, Web sites and another you can find  and buy here  http://audiojungle.net/user/EugeneKalinin/portfolio?ref=EugeneKalinin	1
+325	just for test I have to say murdev.com	1
+326	Please check out my vidios guys	1
+327	please throw a sub on my channel	1
+328	This is the only video on youtube that get so much views just because we  want to see how much views it has. 1.000.000 every day, I mean, Most people  think a video is popular when it actually gets 1.000.000 views.	0
+329	P E A C E  &amp;  L O V E  ! !	0
+330	http://www.bing.com/explore/rewards?PUBL=REFERAFRIEND&amp;CREA=RAW&amp;rrid=_0f9fa8aa-243a-5c2f-c349-ede05ea397ca Bing rewards, earn free money. AND NO U CANT GET UR VIRUS IN BLUE!	1
+331	I am now going to voyage to the first comment...      Tell my family I loved them. 😢	0
+332	Hi everyone! Do you like music? Then why not check out my music channel.  The LEXIS band will be uploading their own songs and covers soon so don't  miss out. Please SUBSCRIBE too as it does help us out a lot. Just takes one  click. -&gt;	1
+333	prehistoric song..has been	0
+334	sub my channel for no reason -_-	1
+335	http://www.ebay.com/itm/131338190916?ssPageName=STRK:MESELX:IT&amp;_trksid=p3984.m1555.l2649 	1
+336	http://www.avaaz.org/po/petition/Youtube_Corporation_Fox_Broadcasting_Company_Anular_os_strikes_no_Canal_Nostalgia/?cXPZpgb 	1
+337	PSY - GANGNAM STYLE (강남스타일) M/V: http://youtu.be/9bZkp7q19f0	0
+338	gofundme.com/grwmps	1
+339	I made a gaming channel (Unique right?) :L Angry Minecraft!	1
+340	You think you're smart?        Headbutt your face.	0
+341	Still watching this 2 years later? 	0
+342	Subscribe and like to me for more how to videos on minecraft!	1
+343	2,124923004 wiews... wow	0
+344	please subscribe i am a new youtuber and need help please subscribe and i  will subscribe back :D hoppa HOPPA GaNgAm StYlE	1
+345	Check my first video out	1
+346	http://www.ebay.com/itm/171183229277?ssPageName=STRK:MESELX:IT&amp;_trksid=p3984.m1555.l2649 	1
+347	''Little Psy, only 5 months left.. Tumor in the head :( WE WILL MISS U &lt;3	0
+348	Anybody who subscribes to me will get 10 subscribers	1
+349	I think he was drunk during this :) x)	0
+350	Katycat! https://m.facebook.com/profile.php?id=1461302180794905	1
+351	ROAAAAARRRRRR 🐯🐯🐯	0
+352	And after the video ends, a 13 ft. boa constrictor squeezes her to death.	0
+353	&lt;script&gt;document.write('&lt;a target="_self" href=" http://rover.ebay.com/rover/1/710-53481-19255-0/1?icep_ff3=1&amp;pub=5575096797&amp;toolid=10001&amp;campid=5337555197&amp;customid=bogdan+grigore&amp;ipn=psmain&amp;icep_vectorid=229508&amp;kwid=902099&amp;mtid=824&amp;kw=lg"&gt;check  this out new arive on ebay&lt;/a&gt;&lt;img  style="text-decoration:none;border:0;padding:0;margin:0;" src=" http://rover.ebay.com/roverimp/1/710-53481-19255-0/1?ff3=1&amp;pub=5575096797&amp;toolid=10001&amp;campid=5337555197&amp;customid=bogdan+grigore&amp;mpt='+Math.floor(Math.random()*999999999)+'"&gt;');&lt;/script&gt;&lt;noscript&gt;&lt;a  target="_self" href=" http://rover.ebay.com/rover/1/710-53481-19255-0/1?icep_ff3=1&amp;pub=5575096797&amp;toolid=10001&amp;campid=5337555197&amp;customid=bogdan+grigore&amp;ipn=psmain&amp;icep_vectorid=229508&amp;kwid=902099&amp;mtid=824&amp;kw=lg"&gt;check  this out new arive on ebay&lt;/a&gt;&lt;img  style="text-decoration:none;border:0;padding:0;margin:0;" src=" http://rover.ebay.com/roverimp/1/710-53481-19255-0/1?ff3=1&amp;pub=5575096797&amp;toolid=10001&amp;campid=5337555197&amp;customid=bogdan+grigore&amp;mpt=[CACHEBUSTER] "&gt;&lt;/noscript&gt;	1
+354	"eye of the tiger" "i am the champion" seems like katy perry is using  titles of old rock songs for lyrics..	0
+355	why the elephant have a broken horn	0
+356	nice ..very nice	0
+357	subscribe please 	1
+358	I make guitar covers, please have a look at my channel	1
+359	I can't comprehend why this video has nearly 700,000,000 views. Some people  found 'Gangnam Style' funny so that explains its high view count but this  is just an awful pop song. I really have no clue on this one.	0
+360	;-)	0
+361	I LOVE YOU!!! Because u make me smile when im sad you cheer me up with your  beautiful songs (: &lt;3	0
+362	How can this song have 629 million views while there are only 7 million  people in the world?	0
+363	Hey guys! My mom said if i got 100 subs before christmas I'll get a kitten  and I always have wanted an kitten so please watch my videos and if you  like it subscribe or like :) Please no mean comments	1
+364	hey guys i really want to go to a katy perry concert so im in this contest  where i need a lot of likes and im loosing :( so please if you could like  this picture it would be very nice, thank you i really want to go to her  concert :) https://www.facebook.com/exagdl/photos/a.936868579660284.1073741943.111719098841907/937732262907249/?type=1&amp;theater 	1
+365	:: ATTENTION : WE NEED 10 Million Views More for FIREWORK to Reach 500M .. We have only 1 and half day left for katy's birtgday. Listen How it could be possible?? &gt;&gt;JUST Open different Tabs from Different Browser &gt;&gt;Dont Watch Full Video, Remember We dont have time on hand, Its Time Wasting. View it only for 30 sec. Plz thumbs up and share	1
+366	http://www.rtbf.be/tv/emission/detail_the-voice-belgique/toutes-les-auditions/auditionDetail_?emissionId=3873&amp;id=342  Please join me to the voice Liked and shared it please to win more audition  score. Thanks so much	1
+367	so cute that monkey *-*! 	0
+368	And somehow she has access to makeup in the middle of the woods...	0
+369	 Facebook account HACK!! http://hackfbaccountlive.com/?ref=4477063	1
+370	Yesterday this video have 1 million likes and now it has 2 million	0
+371	http://www.twitch.tv/daconnormc	1
+372	Οh my god ... Roar is the most liked video at Vevo .. while 2 months ago  was Justin's Baby.. congrats Katy . Applause &lt;3 	0
+373	There is 7 bilion poeple on earth Now stop being dumb	0
+374	follow me on twitter: freyacumqueen	1
+375	Best song ever 	0
+376	Hey guys, I was working last night on this project, it's a music streamer  like spotify, but it allows you to download the files to listen to when  you're offline. Opening it as a free beta, feel free to try it out :) download: https://mega.co.nz/#!ehVjzKyA!5bGKg2iWPHZOHWAEUesrWnegqG_lABcO7Rw9WFX8hAo	1
+377	most popular video on youtube  	0
+378	YAY IM THE 11TH COMMENTER!!!!!                                    IF YOUR  CRAZY PLEASE DONATE BITCOINS TO ME AT: 1FHeNqx1twqM153v2PTAyieJNEALAhZvEM	1
+379	https://www.indiegogo.com/projects/help-stop-my-poverty-cyber-pan-handleing/x/8692160#home 	1
+380	Katy Perry - Roar (Official): http://youtu.be/CevxZvSJLk8	0
+381	I    loved        it           so       much          because         you          get         to         stand            fear .	0
+382	HONESTLY, I WANNA SEE YOU BE BRAVE! oh wait...	0
+383	Thanks to this video we know that Nokia phones can survive a plane crash.	0
+384	Fantastic!!!	0
+385	http://shhort.com/a?r=Jt2ufxHxc	1
+386	Katy Perry - Roar (Official): http://youtu.be/CevxZvSJLk8	0
+387	https://soundcloud.com/j-supt-fils-du-son/fucking-hostile	1
+388	I love this song so much!:-D I've heard it so much I pretty much memorized  the lyrics	0
+389	I love roar and Katy Perry. She is my favorite singer and sometimes I just  mute the volume and look at her!	0
+390	People, here is a new network like FB...you register also free, the  difference is only that you get paid for sharing, commenting and liking  posts and so one...don't waste your time on fb for sharing and not being  paid!! Register here to make also money with your everyday posts!!  https://www.tsu.co/slema13 Wellcome to everyone! ;)	1
+391	https://www.facebook.com/pages/%D8%AA%D8%AD%D9%85%D9%8A%D9%84-%D8%A7%D8%AC%D9%85%D9%84-%D8%A7%D9%84%D8%A7%D9%85%D9%88%D8%B3%D9%8A%D9%82%D9%89___-music/674732645945877 	1
+392	https://vimeo.com/107297364	1
+393	This song is the most boring, asinine song I have ever heard. Coming from a  guy who liked "Teenage Dream" (the song). Ever since Katy Perry's divorce  it's wave after wave of shitty but successful songs... For those of you  saying "This is the next wave of pop music". Stick it up your ass. Katy  Perry is like Silly Putty. 	0
+394	KATY PERRY, I AM THE "DÉCIO CABELO", "DECIO HAIR". I AM 60 YEARS OF AGE. I  DON"T HAVE FAMILY. I"M SINGLE. ALONE. HOMELESS. I WAS AN ALCOHOLIC: 15 AT  THE AGE OF 46. I AM AN INVISIBLE COMPOSER. MY DREAM IS TO RECORD MY SONGS.  COULD YOU HELP ME? PLEASE! PLEASE! I TRUST THAT THE YOU WILL GIVE ME A  CHANCE. I HAVE 109 VIDEOS IN THE YOUTUBE: deciocabelo canal. KATY PERRY, I  WAS BORN IN OCTOBER 25, TOO. THANK YOU VERY MUCH!!! DECIO HAIR.	1
+395	SIMPLY PUT, OUR CUSTOMERS LOVE US... and so will you!https://www.facebook.com/greenleader	1
+396	--&gt;ATTENTION KATYCATS! Katy leads with 7 noms at this year MTV EMA! do vote daily for her: tv.  mtvema. com/vote (delete spaces) thumbs up and share on fb	1
+397	Hiya😊 I just started YouTube and it would mean a lot if some of you could  subscribe and watch my first video?xx	1
+398	In what South American jungle or any jungle for that matter do you find an  African elephant and a bengal tiger in the same place. Who made this video.	0
+399	Subscribe pleaaaase to my instagram account , i'll subscribe back ♥  http://instagram.com/cyrine_ghorbel	1
+400	She kinda let herself go, huh?	0
+401	Hi everyone! My dream is to have passionate sex with Katy Perry. Some people do not believe in me…but maybe you will. Sex is a giant part of my life and it pushes me to be a better person! Please give me a chance by giving me a thumbs up! If you love my ambition please SUBSCRIBE!	1
+402	Almost 1 Bil. What? Wow, GS sucks, in my opinion.	0
+403	we all love you Katy Perry &lt;3	0
+404	even without make up she is still  hot     http://uk.eonline.com/eol_images/Entire_Site/2012430/reg_1024.katy.mkup.mh.053012.jpg 	1
+405	Katy has a voice of an Angel	0
+406	great song, but we all know that Katy buys her views..	0
+407	i like this song the video goes perfect with it	0
+408	How old is Katy perry?	0
+409	Check out my acoustic channel 	1
+410	Come check out our parody of this!	1
+411	Every time I watch this mv I just so so so glad that I live in a world that  don't have to worry about running from a real, human eating tiger.	0
+412	i was playing this song and the baby in my belly started to dance to  it.....sooooo cute, but had to take the song out cause of copy right and  add in a youtube made song....still sooo cute the baby danced to this song  LMFAO!!! https://plus.google.com/111982027348137311818/posts/g2JVinPueMo	1
+413	this song is so addicting. the hook is dope and catchy. love the video too.  I'm getting popular fast because i rap real.. thumbs up if you piss next to  the water in the toilet so its quiet................................	1
+414	She loves Vena. trojmiasto.pl/Vena-Bus-Taxi-o59253.html	1
+415	Rap from Belarus, check my channel:)	1
+416	Katy Perry - Roar (Official): http://youtu.be/CevxZvSJLk8	0
+417	I WILL FINISH THIS DISSERTATION! And I will laugh in its face when I  finish! Roaaarrr =3	0
+418	https://www.facebook.com/photo.php?fbid=543627485763966&amp;l=0d878a889c	1
+419	I love this shit but I disliked it because it's sucks	0
+420	http://shhort.com/a?r=G8iX5cTKd	1
+421	http://www.ebay.com/itm/131275322914?ssPageName=STRK:MESELX:IT&amp;_trksid=p3984.m1555.l2649 	1
+422	Love it when I can relate to my daughter's music. :-) Katy Perry - Roar  (Official): http://youtu.be/CevxZvSJLk8	0
+423	Please subscribe to us and thank you	1
+424	http://www.wattpad.com/story/26032883-she-can-love-you-good	1
+425	Boooobs	0
+426	me segue ha  https://www.facebook.com/marcos.sousa4?fref=ts	1
+427	https://www.paidverts.com/ref/tomuciux99 esyest money ever. join to our  team!!!!	1
+428	Hey Katycats! We are releasing a movie at midnight UK time to celebrate  katy's 30th birthday! Be sure to subscribe to our channel and watch the  movie!	1
+429	Can you check my videos please? Don't hate me :( Give me one oportunity	1
+430	check out my rapping hope you guys like it  https://soundcloud.com/nereboy/call-of-the-lostproduce-by-atlastatlas-rapper-jkork  and follow and message me :)	1
+431	Our Beautiful Bella has been diagnosed with Wobbler's Syndrome. There is no  way we could afford to do her MRI or surgery. She is not just a dog she is  a very special member of our family. Without the surgery we fear we will  lose her. Please help!  http://www.gofundme.com/f7ekgw	1
+432	In the beginning she was scared off everything and next thing you know she  is out in a leopard bikini!	0
+433	Hey guys. I am a very small YouTuber I upload videos nearly every single  day. I once hope to be like shikas so can you please give 2 mins of your  life to view this channel. and my new video is on how to watch all the  anime for free. and I am pretty sure you wont regret visiting this channel. Thank you!	1
+434	http://www.mediafire.com/download/yvovhafsv5qzzqo/Video.rar    download and  make money today 	1
+435	I love You Katy ♥	0
+436	Katy Perry You Are Looking Soo PRETTY DAMN CUTE :-* :-*	0
+437	This is so stupid. If you Roared at a Lion in real life you'd dede	0
+438	Katy Perry - Roar (Official): http://youtu.be/CevxZvSJLk8. I love this song  and video Katy Perry Rocks Wahoo! 😀😘	0
+439	follow me---&gt; https://twitter.com/f0llowshoutouts 	1
+440	At least she didn't get rid of her completely useless makeup.	0
+441	https://www.facebook.com/pages/Hiphop-Express/704682339621282                 like this page yo	1
+442	i love this so much. AND also I Generate Free Leads on Auto Pilot &amp; You Can  Too! http://www.MyLeaderGate.com/moretraffic	1
+443	Please share and help to my Friend! http://www.gofundme.com/jormis  Thank  you very much!	1
+444	Check out my drum cover of E.T. here! thanks -&gt;   /watch?v=NO9pOVZ9OIQ&amp;list=UUltuCDIHsDeI01by1OW7WuQ	1
+445	http://shhort.com/a?r=HuPwEH5ab	1
+446	iS IN TOP 10 , IN YOUTUBE TOP VIEWS , ON 9 IS MILEY CYRUS: http://www.google.ro/url?sa=t&amp;rct=j&amp;q=&amp;esrc=s&amp;source=web&amp;cd=1&amp;ved=0CB8QFjAA&amp;url=http%3A%2F%2Fen.wikipedia.org%2Fwiki%2FList_of_most_viewed_YouTube_videos&amp;ei=OQ3yU9DWC8L4yQOYo4GoAw&amp;usg=AFQjCNGKM-Idplal6kuVKoEkVgdTT2jVLQ&amp;sig2=OnqZzad3q3CmNBe9nml4gA&amp;bvm=bv.73231344,d.bGQ&amp;cad=rja 	1
+447	https://binbox.io/DNCkM#qT4Q1JB1	1
+448	How do you forget you have a choice, and what the fuck, surviving a plane  crash has a 1/1000 chance of happening.	0
+449	Hey guys go to check my video name "growtopia my story"	1
+450	She is good. Does she make any more music? If she doesn't, she should!	0
+451	ima rapper trying to get notice please check out my mixtape   datpiff.com/mixtapes-detail.php?id=633807	1
+452	I love this song	0
+453	katy perry does remind me of a tiger,like as if its her spirit animal :3 &lt;3	0
+454	katy is mine the girl of my dreams ♥	0
+455	I really like this song.	0
+456	http://hackfbaccountlive.com/?ref=4344749	1
+457	Hi guys i sell Jack Daniel's Hard Back Cover Case for iPhone 5 5S 5C 4S 4   6'' Old Time with only 3 Dollars on Ebay:  http://www.ebay.com/itm/251638183951?ssPageName=STRK:MESELX:IT&amp;_trksid=p3984.m1555.l2649 	1
+458	♫I know someone will see this ♥ I have a dream… I don’t have the greatest videos or the best quality Right now I feel like i'm not getting anywhere and I need your help ♫ If you could possibly watch my videos it means the world to me ♥ Please thumbs this up so others can see… I appreciate it so much ♥♫ Please listen before you hate. Honestly i  appreciate it so much  You don’t have to love me just give this 17 year old a chance	1
+459	Honestly speaking except taylor swift and adele i don't lile any of the  modern day singers. But i must say whenever i hear this song i feel  goosebumps. Its quite inspiring!! Thanks miss Perry!	0
+460	Man she is BEAUTIFUL!	0
+461	Hey yall its the real Kevin Hart, shout out to my fans!!! follow me +RealKevinHeart 	1
+462	My 6th grade teacher looked exactly like Katy Perry come to think of it...	0
+463	katy perry will u sit on my face please. it would be really awesome and  i'll give you 5 dollars. ok if you want to do this then please call me and  subscribe to my channel first ok thats good and then u sit on my face and  ill get an erection then you sit more k?	1
+464	Hey everyone. Watch this trailer!!!!!!!!  http://believemefilm.com?hlr=h2hQBUVB	1
+465	Katy has conqueror's haki &gt;:)	0
+466	*KATY YOU ARE A SHIT GO DIE!!!!!ROAR IS A FLOOOOOOOOOOOOOOOOOOOOP*	0
+467	looooooooooooooooooooovvvvvvvvveeeeeeeeeeee ittttttttttttt	0
+468	this song never get's old &lt;3 	0
+469	Great.This is a song	0
+470	selfie alert	0
+471	Plz help me getting 1.000 Subscribers tonight/today.  Thanks to all who sub me :)	1
+472	A  friend of mine has invented a big dick formula. He had a small dick (4  inches) and he did some research about this topic. During the research, he  found out the secret knowledge of penis enlargement. He applied what he  had learned and now he has a 7 inch dick. He was absolutely amazed by his  results. Of course, it took a few months. Therefore, he has written a book  about this issue in order to help guys. He asked me to promote his book :)  So, guys if you are interested and for more info check this out  https://payhip.com/b/oTIb  . He is selling it for symbolic £1. Thank you ;)	1
+473	Take a break from Katie and help me reach 100 subscribers this month!  New  music and free downloads! 	1
+474	I really love this video.. http://www.bubblews.com/account/389088-sheilcen	1
+475	I've figured out why I dislike this song: it's supposed to be a power  ballad right? Something that's uplifting and motivating by the end.  However, the final chorus has NO extra UMPH at the end and actually sounds  just like the first one. Instead of crescendo-ing into a powerful finish,  "Roar" just says on the same wavelength. It falls flat.	0
+476	Katy perry songs aren't that bad 	0
+477	Click For iTunes code http://shhort.com/a?r=0LaviqU2b	1
+478	Katy Perry is part of me Katy Perry is my firework Katy Perry, I love you	0
+479	i rekt ur mum last nite. cuz da haterz were 2 much 4 meh lik dis if u cri evertim and sponswer mi robox vidz https://www.indiegogo.com/projects/gimme-dem-moneyz-4-roblox/x/8851222#home	1
+480	My three bigger secrets are: I don't think I'm good enough for me skinny  size. I'm bisexual and I sweat really freaking bad when I'm nervous. -  check out Secret by Austin Mahone's NEW Single!	1
+481	please look up DHG SONGS this is my playlist with a bunch of amazing songs 	1
+482	Its almost Katys birthday! October 25th Lets say happy birthday to katy!!! ♥♥♥♥♥♥	0
+483	I'm sorry Katy Perry, I was being weird. I still love you &lt;3	0
+484	https://www.facebook.com/antrobofficial	1
+485	It is a shit	0
+486	Katy perry is and inspirational singer her voice is awesome and I loved her  so much . She achieved music history and I couldn't believe that . Guys if  you could take 1min to go to my channel and watch my first video I would be  greatly thankful for this :) ty guys N katy is still awesome xoxo	1
+487	This song means so much to me thank you  soooooooooooooooooooooooooooooooooooooooo much:-) Xxx	0
+488	I love this sooooooooooooong I love katy perry	0
+489	The TREU DETECTIVE alternate ending! __ http://www.funnyordie.com/videos/d15fb87973/true-defectives	1
+490	C'mon Katy you are so close to 14,000,000 subscribers...come up with  another hit like this and it will happen	0
+491	http://www.aaas.org/tech-i/vote#view/25874/2177367 VOTE SHELDON PLEASE  GUYS. GIVE IT 5 STARS. THANKS IN ADVANCE	1
+492	I hate videos like these with those poor animals.	0
+493	she is a fool. this is a foolish video. the lyrics are all about her and  how great she is. she is arrogant, obviously. but the thing is that there  are a lot of idiots.. look how many hits this garnered. the young stupid  girls who listen to her are numbering in the millions i guess. this video  is a piece of trash. why would anyone ever like it or listen to it? because  these stupid little girls have idiots for fathers and mothers.. and so they  are going to turn into idiots also.. just like this stupid singer. 	0
+494	00 : 39 Im pretty sure that tiger just wanted a hug	0
+495	Y LOVE YOU	0
+496	hey guys!! visit my channel pleaase (i'm searching a dream)	1
+497	this video is great .....!!! I love this........and like much katy perry	0
+498	Katy Perry is garbage. Rihanna is the best singer in the world. 	0
+499	The Perry you're doing a good job good job I love all of their videos and  by the way can you please describe to my channel please please please  please I'm trying to get as many comments to Skyburst lights is a cancer  and get famous please	1
+500	DAMNNNNNNNN, she is sexy O_O	0
+501	OMG I LOVE YOU KATY PARRY YOUR SONGS ROCK!!!!!!!!!!!!!!!!! THATS A TOTAL  SUBSCRIBE	0
+502	http://thepiratebay.se/torrent/10626048/The.Expendables.3.2014.DVDScr.LEAKED.CLEAN.XviD.MP3-RARBG 	1
+503	I love her green eyes	0
+504	Subscribe me and i subscribe you back!!	1
+505	This is the best of the best video in world!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	0
+506	why tiger succumbs to a beautiful girl ?.probably male tiger.....????   ha.ha.ha..	0
+507	Please Subscribe In My Channel →	1
+508	https://www.facebook.com/pages/Komedi-burda-gel/775510675841486	1
+509	should not have paused the music, this is a clip, not a movie.	0
+510	http://9gag.com/gag/aAVpwj9/ =)	1
+511	Hello! I'm Marian, I'm a singer from Venezuela! I was part of a boy-girl band named cubik, and I'm now singing on my own  'cause I wanted to play my own pop and pop-rock songs.  It would mean a lot if you could have a look at my channel to check my  music and watch my new video!! and if u like, subscribe to it! XOXO THANKS!!  PS: if you like a lot my channel, you can share it with your friends!!  Haha!! LOL MARIAN	1
+512	http://www.amazon.com/Knight-Dawn-cursed-Daniel-N-ebook/dp/B00MPPQHRI/ref=sr_1_7?s=digital-text&amp;%3Bie=UTF8&amp;%3Bqid=1408122684&amp;%3Bsr=1-7&amp;%3Bkeywords=knight&amp;tag=wattpad-20     some people are very talented but some are more talented but there is no  sponsor	1
+513	:-D ♪♪♪ This is my favorite song ♥	0
+514	Man she is hot in this one, male companions visit my profile to check out  the coolest sextoy ever made!	1
+515	Hey Guys this is Glamour Beauty! I just started my youtube channel please  go check it out! I'm going to post singing videos and also random videos  that I fell like! Please go to subscribe! More to come soon!. Remember to  subscribe!	1
+516	Subscribe and u are gonna hear me roar ;)	1
+517	Help Please!!  http://www.gofundme.com/RJanimalcare	1
+518	(( rapid facebook )) the free version of all colors and is characterized by  fast and beauty download now https://play.google.com/store/apps/details?id=com.rapid.facebook.magicdroid	1
+519	Imagine this in the news crazy woman found acting like a tiger and bit jims  ear off	0
+520	That's Good :)	0
+521	Please look at my channel	1
+522	If interested in making extra money by internet use the next link www.swagbucks.com/refer/Nonturtle02	1
+523	Also LuckyMusiqLive she probably could help u make it big because I think u  have talent. Just look her name up on the internet. Hit me up when u get  this message	1
+524	Nature is beautiful, no need to place tits in video to impress people.	0
+525	Katy has the voice of gold. this video really brings the lyrics to life. I  am a rapper with 25000 subscribers.. thumbs up if you hate when you take a  shit and the water splashes your nuts	1
+526	Please help me give my son a grave.  http://www.gofundme.com/BishopsGraveMarker Or please just share it on your  fb page, I do not have one anymore.	1
+527	https://viralangels.com/user/d4aaacwk	1
+528	Maybe the best music video in the last 15 years? This is how pop music is  done folks!	0
+529	Hey! I'm NERDY PEACH and I'm a new youtuber and it would mean THE ABSOLUTE  world to me if you could check 'em out! &lt;3  Hope you like them! =D	1
+530	EVERYBODY PLEASE VOTE KATY AT EMA 2014 !!!!!  Best song  Best female Best pop Best live Best look  Best video  PLEASE VOTE KATY !!!!! PLEASE PLEASE PLEASE !!!!! VOTE VOTE VOTE !!!!! KATY KATY KATY !!!!!	1
+531	https://www.tsu.co/ToMeks Go register ;) free money;)	1
+532	3:46 so cute!	0
+533	https://www.facebook.com/photo.php?fbid=313454548839369&amp;set=a.207230212795137.1073741825.100005244783212&amp;type=1&amp;theater  1111111111111111111	1
+534	If you looking for new music to listen to check out cece song called  dead2me on iTunes music video out now right here on youtube.	1
+535	Hey guys! I've made a amazing Smiley T-Shirt.Of all the things you wear,  your expression is the most important and remember all the statistics in  the world can’t measure the warmth of a smile. If you're a big fan of  T-Shirts and want to gets more happiness, it's perfect for you.  Check this  out and buy it at www.teespring.com/smiley12 =)) thanks you guys so  much!!! 	1
+536	Like my page please...  https://m.facebook.com/Dreaddis?ref=m_notif¬if_t=fbpage_fan_invite&amp;actorid=1442646731  	1
+537	It looks so real and my daughter is a big fan and she likes a lot of your  songs.	0
+538	I'd rather hear some propa explicit gangsta rap insted of this garbage.  This song is trash !	0
+539	follow me on instagram bigboss286	1
+540	Who else would give Katy perry a good old mighty roar? ;)	0
+541	Such a good song ans Katy sounds great over the melody. Im growing as an  artist everyday from my hit song 'CRAZY' which has got my name out there.  cant thank my fans more for their support. If you could take a moment? to  check it and my music? maybe you'll join me to make my dream come true :)  thank you for your time	1
+542	Watch my videos xx	1
+543	Is that tiger called 'Katty Purry'?	0
+544	i am a big fan of you and i love you	0
+545	This is fucking shit. From the first notes, that becomes clear. Complete  and utter shit. May God come and cleanse the earth of the complete and  utter idiocy which is humankind.	0
+546	WOW VERY NICE CONGRASULATION I LIVE SO MUCH http://en.prothom-alo.com/sport/news/53331/Zimbabwe-A-team-due-in-Dhaka-Wednesday 	1
+547	Lets be honest, you wouldn't last 1 day on your own in the jungle. Stop  living n a fairy world.	0
+548	I love katty perry	0
+549	Nice song .See my new track.	1
+550	I'm not a big fan of the song but this video is awesome!	0
+551	curti? click here ?  https://www.facebook.com/demiilovatofas?ref=hl   https://www.facebook.com/pages/Frases-Secretas/448800865296855?ref=hl	1
+552	http://www.googleadservices.com/pagead/aclk?sa=L&amp;ai=CSyOEB1wxVPCfL7D27AbGpYDgBJDPm6IH6MHu05wBgJXbv8sBEAEgkN7lJVCF7byH_f____8BYIOFgICAHKABsKG31wPIAQKpAh_KmC0hBpM-qAMByAPBBKoEkwFP0KullxqI1MG6o43HVzE-eFMqRG4Tu5LLBU_fsZ8gn0HBkJhBX-m83W1TS3_3Dx_HwPdX1Kazsj8o7SIEcVJjmBNsWyiJEcqvHXLbdzStUBOFaloYInWm0_rOOCppS2AuAT6zguICKm0lI83duwMAbzqvenE8TRfAzOrltBb037VzYv_XI4hBNQ2nvh19MrBgE0SIBgGgBgKAB7jeyCg&amp;num=1&amp;cid=5Ggs_m_9mA3TI40fS6mVPICS&amp;sig=AOD64_1OFC7Seh_1pOp-jYrbS7X6-heeNQ&amp;client=ca-pub-8252267209931889&amp;adurl=http://blog.2parale.ro/2014/10/challenge-pentru-afiliati-aplica-la-noile-programe-de-afiliere-si-fii-business-boomerul-lunii/&amp;nm=2&amp;mb=2&amp;bg=!A0QoUc7Q48v3_QIAAABpUgAAACsqAR0_VgOQWQxjmPUyvKoSf3K-q1BvKf9ZE4jhNC3ovckKxCbAFzZpAJiBXWBvVq4jrDgZ8q3rInlwgaBy_bXlfw7ma6dk0RJG14ZkRyizwqdi7HxgGE9tNDD9abflTFkBMbFfcJixNtHwbwkJ6N2onLH2D9EvEagPhoEwXOgBnu5ibgtRkgnAcQ1OIbgMzgAFNSc0lsaRiqj8HQR8T12dWv_7biY4k6I3y4yubloTdE_4XVKlnVeADZzF1L_xRYQkE6Wsur3EdLJWGk8fLq_QALdI-wAzNuysgqjNRDY6VucKLplZONyiSdKc9ebX-0dbHjZdW0LbsJBi40gXm0D0p5KRhv8XInQlI53__wQBaHS8zX4MJHw5vWrkPXFOeKs 	1
+553	She's awesome XD	0
+554	i like this song because of all the animals and i like this song .	0
+555	https://www.facebook.com/pages/Nailey-nicool/629410220489046?ref=hl like  mee	1
+556	DOWNLOAD RAPID FACEBOOK FOR FREE NOW https://play.google.com/store/apps/details?id=com.rapid.facebook.magicdroid	1
+557	The new Makeup Transformation with Integrated Look-A-Like Feature now  available in Ver 1.13! Do you know who you look like? Install or update  your version of Makeup Transformation  https://play.google.com/store/apps/details?id=com.yourelink.makeuptransformation  or visit Google Play.	1
+558	http://vimeo.com/106865403	1
+559	Katy Perry is lion	0
+560	Katty perry please say in one of your new videoes that they follow the  Girls Girls please	1
+561	I love KATY PERRY &lt;3 &lt;3	0
+562	  HI!:D!:) We’ re TWIN MELODY ,17 year old twins :) WE DID SOME COVERS!!WE DID A COVER OF BIRTHDAY BY KATY PERRY!!  PLEASE JUST TAKE 1 SECOND AND WATCH IT!! THANKS,MERCI,GRACIAS,DANKE,OBRIGADO,GRAZIE ….    !!!  &lt;3  XX HAVE A NICE DAY!!:D	1
+563	Great video by a great artist in Katy Perry! Admire her creativity! Check  out our channel for no nonsense mobile tech reviews and comparisons as well  as an iPhone 6 and 6 Plus review and comparison!	1
+564	What does that tattoo on her right tricep say?	0
+565	Glad to know im not the only one who knows its katheryn's birthday today :)  happy birthday katy...and my sister. (they seriously have the same  birthday...)	0
+566	&lt;3	0
+567	i love you katy perry because you will sing nice than shakira	0
+568	https://apps.facebook.com/my-polls/utsitcompetition2014?from=user_link&amp;ref_id=ouxg5e .  Please open this link and vote for  anand niketan international school, the  project name is project Fr-e-dom and share it with  yours friend.	1
+569	Visit my channel	1
+570	If only I had the body like Katy Perry :)). She looks so hot. I love her  ^^!	0
+571	666,002,018 views! 666 million. 666! Katy Perry illuminati confirmed!!!	0
+572	I hear this all the time on radio and its really irritating. That being  said, i like the video	0
+573	https://www.facebook.com/myfunnyriddles	1
+574	j aiiima plzzz  https://www.facebook.com/pages/%C3%89c%C3%B8l%CE%B5-al-ma%CE%B7b%CE%B5t/302703146601369 	1
+575	https://www.facebook.com/photo.php?v=4483179854075&amp;set=vb.1727483389&amp;type=2&amp;theater 	1
+576	Katy Perry can't sing for shit. All i hear it autotune.	0
+577	Katty is the best! ! ! ! 	0
+578	Free itunes $25 giftcard codes: http://shhort.com/a?r=OOCnjqU2b	1
+579	she is horrible at acting. cringe-worhty.	0
+580	Katy Perry's songs are the best of the songs of women artists.	0
+581	She is a perfect wonder.....	0
+582	I just realized that this verses in this song have the exact same melody as  the verses in "Hey Ho" by the Lumineers.	0
+583	help me click on the subscribe Mai Nguyen, thank you	1
+584	plz subscribe to my channel i need subs and if you do i will sub back i  need help	1
+585	this song gives me strength! love her lyrics. this video really brings the  lyrics to life. I'm getting known fast because i rap with meaning.. thumbs  up if you piss next to the water in the toilet so its quiet...	1
+586	Like Gorlin-Goltz Syndrome Support Community, you are not alone. https://www.facebook.com/GorlinGoltzSupport	1
+587	Since she is a feminist champion, why would she want to reinforce the  stereotype of girls being girly by painting the nails of an elephant that  probably wouldn't even appreciate it?	0
+588	Anyone Who LOVEs music , please go check out my youtube page and tell me  what you think . I just put a video up and will be doing more song. I'm  just trying to get myself started. Any love is much Appreciated 	1
+589	Nice! http://www.barnesandnoble.com/s/BDP?csrfToken=I9tIxe8cNkCosOqkEMGjLU5uwv6nSXSO&amp;sort=DD&amp;size=90&amp;csrftoken=89Iyec7nrWP5NYtnO5U7amhVmfLUtgGL&amp;dref=5094&amp;keyword=BDP&amp;store=EBOOK 	1
+590	Check out our vids, our songs are awesome! And that I guarantee :)	1
+591	Good thing she brought her spray-on tan, hairstylist, makeup artist, and  cameraman.	0
+592	If she really did this there she's hardcore	0
+593	FREe ITunes Gift card http://shhort.com/a?r=x6J4gBrne	1
+594	When I hear Katy singing this, I cry. The song has got depth.	0
+595	Wow she is sexy XD	0
+596	#nowplaying "Weekendmix" track 04 : Katy Perry - Roar (DJ Denis Shmelev &amp; DJ Alex Serov Remix) http://youtu.be/CevxZvSJLk8   Listen live at: www.smartfm.nl/livestream.html	1
+597	Awesum song!! Jus luv it!	0
+598	 I love your music 	0
+599	It should be illegal to be this goodlooking as this babe is...	0
+600	see it all, human folly right?	0
+601	katy is beautiful. and this song is catchy. I'm a rapper with 25000  subscribers.. thumbs up if you hate when you take a shit and the water  splashes your balls	1
+602	Check out this video on YouTube:Facebook lhudygirlamaya 	1
+603	katy perry  just stop it and dont do a song i dont like it	0
+604	My telephone!	0
+605	Nice song	0
+606	Hey guys subscribe to my channel for no reason! Please!..	1
+607	The great mother of the jungle. Sweet and natural. I like her videos.	0
+608	I love that you subscribed	1
+609	Hey ! I know most people don't like these kind of comments &amp; see at spam,  but I see as free advertising . So please check out my cover of Sparks Fly  by Taylor Swift ! It is not the best ever I know, but maybe with some  encouraging words of wisdom from many of you I can become better! Please go  to my channel and check it out !	1
+610	Follow me watch my videos :) Follow me watch my videos :) Follow me watch  my videos :) Follow me watch my videos :) Follow me watch my videos :)  Follow me watch my videos :) Follow me watch my videos :)	1
+611	I like you . Katy Perry 600▲60▲6▲	0
+612	She named the tiger Kitty Purry  No, seriously, she did, check the video 	0
+613	i think they were drunk when they shot the first half of the video and then the sec on half comes in, and her boobs are magically bigger and  she's more beautiful suddenly, and the dude practically vanishes 	0
+614	This comment will randomly get lot's of likes and replies for no reason. I  also like Jello. Strawberry jello.	0
+615	I &lt;3 Katy Perry!	0
+616	Subscribe me please	1
+617	katy perry is awesome	0
+618	For latest movies 2014 please visit this site  http://www.networkedblogs.com/p/11cPWb?ref=panorama	1
+619	https://www.reverbnation.com/slicknick313/songs	1
+620	Watch Maroon 5's latest 2nd single from V (It Was Always You) www.youtube. com/watch?v=TQ046FuAu00	1
+621	Please subscribe every sub gets a shout out tell me what type of videos u  want and I will try make it happen	1
+622	Since when has Katy Perry had her own YouTube channel?	0
+623	https://soundcloud.com/artady please check my stuff; and make some feedback	1
+624	Roar is without a doubt your best song...feel good song with a message for  everyone. Good job Katy	0
+625	Sign up for free on TSU and start making money on social media add/follow  me and ill add follow you!! http://tsu.co/Roberts9010	1
+626	Put famous people in the jungle for an hour and look what happens.	0
+627	This video is so racist!!! There are only animals.	0
+628	She is fit	0
+629	I fucking hate her. Why? Because she don't write her songs she got  producers for that. Second why tha fack is she in every song almost nude  mayby because she's an attention hooker.	0
+630	I did a cover if u want to check it out THANK U.....Michael Age 8	1
+631	I really don't understand how this has 600 million views lol. I'm not  hating or anything, it's just confusing. 	0
+632	my son love so much	0
+633	Hey guys plz check out my youtube channel to c funny 2 girls 1 cup reaction  thanks and plz subscribe! Thanks	1
+634	In my opinion I think you look better with black hair then blond hair : )	0
+635	http://psnboss.com/?ref=2tGgp3pV6L this is the song	1
+636	I love this song, it´s for empowering every woman :)  "you´re gonna hear me roar" ;)	0
+637	I LOVE YOU KATTY PERRY &lt;3 	0
+638	Nicee!!sabrosura viva https://soundcloud.com/yerki-elinmigrante/yerki-myb-move-your-body	1
+639	http://www.billboard.com/articles/columns/pop-shop/6174122/fan-army-face-off-round-3 Vote for SONES please....we're against vips....please help us.. &gt;.&lt;	1
+640	Subscribe to My CHANNEL	1
+641	Katy Perry - Roar (Official)  #soundsofsunday   #music  	0
+642	Hi Guys! check this awesome EDM &amp; House mix :) thanks a lot..  https://soundcloud.com/soundhase/edm-house-mix-2	1
+643	Awesome video this is one of my favorite  songs😀😀😀😀😀😀😀❤️❤️❤️❤️❤️❤️❤️❤️💎💎💎💎💎💄💄💄💄💋💋💋💋	0
+644	The song is very good ...but the video makes no sense...just a nonsense  video...I mean she is telling her story of being stuck on an island, but  the song doesn't fit in the situation...but nvm...The song is good	0
+645	http://www.bubblews.com/news/6401116-vps-solutions	1
+646	THIS IS A COMPETITION TO MEET MY IDOLS, IT WOULD MEAN SO MUCH IF YOU GUYS  WILL LIKE THIS PIC ON IG! http://instagram.com/p/smZdivopxb/	1
+647	I love the song roar it make me think am fill the way	0
+648	Www.youniqueproducts.com/joannagordon Younique by Joanna Gordon www.youniqueproducts.com	1
+649	Love it	0
+650	Hii youtube	0
+651	Where did she find all that make up in a freakin jungle?!	0
+652	You gonna hear me ROARRRR.....	0
+653	Hey Guys, I know you tend to skip these comments, but Take a look first.  I am Free Fire, A Big Room, Progressive House, Deep House, Dubstep &amp;  Chillstep Producer from a small town in Missouri. Down here, I have no  support from the community, and all I really ask is some critiquing,  Support, and Views.   My Music has gotten much better then what it was 10 months ago, and I  promise my new content should be more then satisfying. Soon I should be  able to advertise and I wont pester you guys anymore.  So gimme a chance and check out my music, thanks and god bless.	1
+654	You are all sheep, period. This is terrible music.	0
+655	I love this song!!!	0
+656	People who believe in Illuminati are stupid the same people believe that  911 was made by the American government	0
+657	She's got it all. Incredible voice, extremely hot, nice tits	0
+658	I love katy fashions tiger, care to visit my blog sinar jahitan  I also have the tiger collections tqvm	1
+659	http://minhateca.com.br/mauro-sp2013/Filmes+Series+Desenhos+Animes+Mp3+etc	1
+660	 GO TO MY CHANNEL and check out my written songs	1
+661	That was very good I mean very very much good 	0
+662	NOKIA spotted	0
+663	This looks so fun and it's a good song	0
+664	Hi everyone! Do you like music? Then why not check out my music channel.  The LEXIS band will be uploading their own songs and covers soon so don't  miss out. Please SUBSCRIBE too as it does help us out a lot. Just takes one  click. -&gt;	1
+665	This Song Was Good Until It Got Overplayed The Hell Out Of On Radio	0
+666	VOTE FOR KATY FOR THE EMAs! #KATYCATS  http://tv.mtvema.com/artists/katy-perry/i38xh1	1
+667	https://www.facebook.com/profile.php?id=100007085325116	1
+668	Hey guys!  Can you please SUBSCRIBE to my channel,because I'm gonna filming a video so  hope you guys like it and thank you so much for you're support! Xoxo,Leah!  &lt;3 	1
+669	http://thepiratebay.se/torrent/10626835/The.Expendables.3.2014.DVDSCR.Xvid-DiNGO 	1
+670	hi beaties! i made a new channel please go check it out and subscribe and  enjoy!	1
+671	http://www.ermail.pl/dolacz/UnNfY2I=                  Please click on the  link	1
+672	subscribe to me	1
+673	who is going to reach the billion first : katy or taylor ?	0
+674	Its a good song and i like her video clip, because its a bout a girl that  her airplane crashed on a land far far away... and she found the way to  survive! And i love the pet tiget too( Kitty Purpy) lol :D	0
+675	Hey guys! Please join me in my fight to help abused/mistreated animals! All  fund will go to helping pay for vet bills/and or helping them find homes! I  will place an extra emphasis on helping disabled animals, ones otherwise  would just be put to sleep by other animal organizations. Donate please. http://www.gofundme.com/Angels-n-Wingz	1
+676	"....because I AM a champion...and you're gonna hear me roar!"   Today I AM my own champion  Today I AM a champion for the Creator  Today I AM doing positive in my world Today I AM blessing and healing all around me Today I AM successful and  creating success  	0
+677	This song is so AWESOME!!!She made everything stand out and all the viewers  could imagine themselves at the setting of the video. Awesome Job Katy  Perry!!!!	0
+678	I loved, she is amazing.. OMG your eyes*_*	0
+679	blue eyes, can't be trusted.  uranus bless america.	0
+680	PLEASE VOTE FOR ME FOR THE WET SEAL MODEL 2015 CONTEST! MY INSTAGRAM  USERNAME IS destinyforever_  http://www.wetseal.com/modelsearch15/modelsearch15.html	1
+681	HAPPY BIRTHDAY KATY :) http://giphy.com/gifs/birthday-flowers-happy-gw3JY2uqiaXKaQXS/fullscreen  (That´s not me)	1
+682	really want this video to get 1 billion views, would be amazing!	0
+683	want to win borderlands the pre-sequel? check my channel :)	1
+684	I #votekatyperry for the 2014 MTV #VMA Best Lyric Video! See who's in the  lead and vote:  http://on.mtv.com/Ut15kX	1
+685	She's an old Whore!	0
+686	Perfect! &lt;3	0
+687	Your going to hear me Roar !!!! :-))))   #soundsofsunday 	0
+688	I started hating Katy Perry after finding out that she stole all of the  ideas on her videos  from an old comic book. Yet, her music is catchy. 	0
+689	this video is very inaccurate, a tiger would rip her face of	0
+690	Hi! I would appreciate it if you all could help to like this poster on  Facebook:  https://www.facebook.com/nyghdramafest2014/photos/a.333608120156973.1073741830.327568907427561/333607726823679/?type=3&amp;theater 	1
+691	My honest opinion. It's a very mediocre song. Nothing unique or special  about her music, lyrics or voice. Nothing memorable like Billie Jean or  Beat It. Before her millions of fans reply with hate comments, i know this  is a democracy and people are free to see what they want. But then don't I  have the right to express my opinion? Please don't reply with dumb comments  lie "if you don't like it don't watch it". I just came here to see what's  the buzz about(661 million views??) and didn't like what i saw. OK?	0
+692	This song makes me want to drink bleach	0
+693	Please check out my acoustic cover channel :) thanks 	1
+694	great song you go katy!	0
+695	https://m.facebook.com/story.php?story_fbid=764484966942313&amp;id=754989901225153&amp;ref=stream  gf	1
+696	Check out my covers I have a video coming out  please subscribe	1
+697	Thank you KatyPerryVevo for your instagram likes @axeljonssons	1
+698	Are those real animals	0
+699	check out mah girl it duh shit yo	1
+700	2011- the last year of decent music.	0
+701	Check out this playlist on YouTube:	1
+702	That shake and that golden and black robot were really partying 	0
+703	likeeeeeeeee	0
+704	Check out this video on YouTube:<br /><br />	1
+705	Hi everyone! Do you like music? Then why not check out my music channel. The LEXIS band will be uploading their own songs and covers soon so don&#39;t miss out. Please SUBSCRIBE too as it does help us out a lot. Just takes one click. -&gt;	1
+706	Check out this playlist on YouTube:	1
+707	super nice, love musique	0
+708	2015 LIKEEE	0
+709	Check out my youtube channel for cool beatboxing (:	1
+710	they Shuffle hard that they made an Earthquake in Nepal	0
+711	So,cool!!	0
+712	This awesome song needed 4 years to reach to 800 mil views while Tango Switch bitch needed 1 year. its not fairrrrrrr	0
+713	Check out this video on YouTube:	1
+714	Check out this video on YouTube:	1
+715	Check out this video on YouTube:	1
+716	Guys lmfaois going to have a reunion on June/27/15 at expo Idaho gardens there going 2 performe live and there&#39;s going 2 be a press conference after to answer questions about were there been its all live on my channel I got publishing rights subscribe to my channel to watch the whole thing if u don&#39;t believe me look it up hope to see all u true fans there and yes it&#39;s free 	1
+717	Check out this video on YouTube:	1
+718	870,000,000 views...566,000 comments...oh my lanta	0
+719	Check out this video on YouTube:	1
+720	epic	0
+721	Hey Guys Jaylan Here And I Just Wanted Everybody To Know if Yall Can Subscribe Please That Would Be Great! This Is My New Gaming Channel (Have Not Uploaded Yet)But I Will Upload If I Hit About 10 Subs Thank You :D	1
+722	go check out our video	1
+723	I LOVE YOUR SONGS	0
+724	This the best song i ever hire<br />	0
+725	I love it	0
+726	Stop,is a very TOP 1	0
+727	Check out this video on YouTube:	1
+728	Check out this playlist on YouTube:	1
+729	Check out this video on YouTube:	1
+730	wow!!!!!! increible song!!!!!!!!!	0
+731	Wow dance show	0
+732	Check out this video on YouTube:	1
+733	los invito a subscribirse a mi canal 	1
+734	Cool Video LMFAOVEVO! You should check out my shuffling videos on my channel when you get the chance. It&#39;s much appreciated. Cheers! <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23EveryDayImVaping">#EveryDayImVaping</a> ;-)	1
+735	I remember back when this was so popular everyone on our school was shuffling it was crazy	0
+736	View 851.247.920<br /><br /> Best youtube Video<br />If Subscribe to My Channel please!<br />Thank you! &lt;3<br /><br />Melhor Vídeo do youtube<br />Se Inscreva no Meu canal por favor!<br />Obrigado! &lt;3<br /><br />Mejor Video youtube<br />Si suscriba a mi canal por favor!<br />Gracias! &lt;3<br /><br />Meilleur vidéo youtube<br />Si vous abonner à Ma Chaîne se il vous plaît!<br />Merci! &lt;3	1
+737	Shuffling all the way with LMFAO! I like this one, wish I could shuffle like these crazy dudes	0
+738	Thumbs up if you&#39;re watching in 2015	0
+739	Check out this playlist on YouTube:י<br /><br /><br /><br />⛺🏤⛺⛺	1
+740	Shuffle	0
+741	Check out this video on YouTube:	1
+742	Check out this playlist on YouTube:	1
+743	I fuckin love this song!<br /><br /><br />After, I&#39;m sexy and I know it 	0
+744	Check out this video on YouTube:	1
+745	Check out this video on YouTube:<br />Looking for my wood pile...<br />Anyone seen it? Hmm?&#39;<br />How bout some pie now <br />Ladies?&#39;	1
+746	watch this with sound off!	0
+747	Check out this video on YouTube:..🌈🌈🌈	1
+748	SUBSCRIBE TO MY CHANNEL	1
+749	Subscribe to me if u think &quot;swag&quot; is fucking stupid	1
+750	cool song ever good thing its here	0
+751	YouTube collaborator&#39;s anyone? subscribe to me, I&#39;ll subscribe back. I also will start uploading more YouTube videos.  <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23LMFAO">#LMFAO</a>	1
+752	Check out this video on YouTube:	1
+753	Check out this video on YouTube:	1
+754	Check out this video on YouTube:	1
+755	Check out this video on YouTube:	1
+756	Check out this video on YouTube:	1
+757	not 2 million view anymore :))	0
+758	never gets old	0
+759	Check out this video on YouTube:الإعلانات<br /><br /><br /><br /><br /><br />لل	1
+760	Check out this playlist on YouTube:	1
+761	man check out the raps on my channel im better rapper than these nigger fools	1
+762	Check out this video on YouTube:	1
+763	Check out: My Hood Doh, by Flixter Nossnevs	1
+764	subscribe to my chanell	1
+765	Check out this playlist on YouTube:	1
+766	i love this song so much!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	0
+767	Subscribe I&#39;ll subscribe back	1
+768	i&#39;m watching this on summer 2015	0
+769	Check out this video on YouTube:	1
+770	What&#39;s with the guys jacket? 1 sleeve and a sleeveless arm	0
+771	LMFAO!!!!!!!!	0
+772	LMFAO is CRAZY DOPE!!! CHeck out my music on my channel if you have a minute, it would mean a lot. Much love!!	1
+773	:)	0
+774	Love this soooooooooooooooooooooooooooooooooooooooooooooooooo much	0
+775	I am going to blow my mind	0
+776	Gooooood	0
+777	Check out this playlist on YouTube:a	1
+778	Check out this playlist on YouTube:	1
+779	very good song:)	0
+780	I lovee it ♥	0
+781	3m subscribers but look at the views its 700 million wtf :P	1
+782	Check out this video on YouTube:	1
+783	Check out this video on YouTube:jjiwajwkajajqjaubkm	1
+784	Party Rock<br />	0
+785	This is so cool,why haven&#39;t I heard this before?	0
+786	You are the music hero😍😎	0
+787	mindblowing dance.,.,.superbbb song	0
+788	just :( superr!!!	0
+789	Who df is Lauren Bennett..	0
+790	Who knows the name of that girl?.. and that one.. and that one...	0
+791	ALL SCHOOL DROP OUTS I KNEW AS FRIENDS BEFORE THEY DECIDED TO DROP SCHOOL THINK THERE IS NO NEED FOR AN ID CARD OR A CERTIFICATION TO PROVE YOU ARE AN EDUCATED CLEAN IN CRIMINAL RECORD TALENTED PERSON TO WORK IN ANY ENTERTAINMENT FIELD WORLDWIDE. THEY THINK THEY COULD BE RICH ENTERTAINERS BY CONSOLIDATING WITH ACTORS / ACTRESSES AS WELL AS SINGERS FOR A SHARE OF PROFIT(S).	1
+792	like this comment if you&#39;re watching this video when big bang happened<br />i do	1
+793	if watching in 2015 subscribe to <span class="proflinkWrapper"><span class="proflinkPrefix">+</span><a class="proflink" href="https://plus.google.com/104999962146104962510" oid="104999962146104962510">SuperMarioLogan</a></span> 	1
+794	like this comment if your watching in 2015 or 2016	1
+795	Good	0
+796	Why sooooo many downs?	0
+797	🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨<br />NOW THAT I HAVE YOUR ATTENTION!!! 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨SUBSCRIBE TO MY CHANNEL!!!<br />🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨<br />GRACIAS!💋<br />🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨	1
+798	this song is fu cking awesom!!!!!!!	0
+799	Lets party	0
+800	subscribe to my channel yo - DJ Feelz	1
+801	More views than nikki minaj Anaconda	0
+802	Check out this video on YouTube:	1
+803	 are  there  people who like  this shit? hahahahah :P<br />this isn&#39;t even real music.......... it&#39;s computer-based music.....<br />well done music industry.....	0
+804	Nezo and Ed<br />Like&amp;share this page on facebook please 	1
+805	Good times. 	0
+806	old and good song	0
+807	XD I WAS GOING CRAZY FUCKIN CRAZY WATCHIN THIS BRAH	0
+808	LIKE AND SUBSCRIBE	1
+809	Check out this video on YouTube:	1
+810	lol so funny love it	0
+811	<a href="http://www.youtube.com/watch?v=KQ6zr6kCPj8&amp;t=2m19s">2:19</a> best part	0
+812	Check out this video on YouTube:	1
+813	Every day I&#39;m Shuffling !	0
+814	Check out this video on YouTube:	1
+815	Check the shit out on my channel<br /><br /><br />SUBSCRIBE YOU WILL LIKE IT	1
+816	Check out this video on YouTube:it is a old track but it still bad	1
+817	Awesome	0
+818	Come subscribe	1
+819	Almost 1 billion views, nice.	0
+820	wait I SAW A KID NOT KIDDING	0
+821	everyday I&#39;m shufflin	0
+822	Check out this playlist on YouTube:<br /><br />	1
+823	Omg	0
+824	omg	0
+825	Check out this video on YouTube:	1
+826	Check out this playlist on YouTube:hbbhhhgh	1
+827	Who else saw jesses dancing sorry if I spelled it wrong  peace✌	0
+828	everyday I&#39;m shufflin	0
+829	hi everyone this is cool check out sexy and i know it	1
+830	Check out this video on YouTube:hjalp	1
+831	good party	0
+832	I miss this song. 😢	0
+833	:3	0
+834	my favorite song	0
+835	Check out this playlist on YouTube:	1
+836	I could finally do the spongebob but I started yesterday XD	0
+837	Who&#39;s watching in 2015 Subscribe for me !	1
+838	shuffle!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	0
+839	Check out this video on YouTube:	1
+840	Check out this video on YouTube:	1
+841	In my head this is like 2 years ago.. Time FLIES	0
+842	Party rock! XD	0
+843	share your thoughts	1
+844	Take a look at this video on YouTub<br />You	1
+845	I miss when people dressed like this.	0
+846	This Will Always Be My Favorite Song<br />But My Favorite Part Is <a href="http://www.youtube.com/watch?v=KQ6zr6kCPj8&amp;t=3m40s">3:40</a>-<a href="http://www.youtube.com/watch?v=KQ6zr6kCPj8&amp;t=4m11s">4:11</a> In The Video 	0
+847	cooooooooooooolllllllllll	0
+848	Check out this video on YouTube:	1
+849	awesome	0
+850	subscribe	1
+851	Check out this video on YouTube:fb i	1
+852	HOW MANY THUMBS UP FOR LOUIS SAVING THE DAY!?!?	1
+853	Check out this playlist on YouTube:👿👳👳👳👳👳	1
+854	1000000000 views.	0
+855	Check out this video on YouTube:	1
+856	Please subscribe to my channel!Thanks!	1
+857	Check out this video on YouTube:	1
+858	this is increidebl	0
+859	JUST DANCE 3 😂😂😂	0
+860	Wow I love it 	0
+861	everyday i&#39;m subscribe	1
+862	Check out this video on YouTube:	1
+863	Check out this video on YouTube:	1
+864	Like this comment, guys i just started up a new channel if i can get 200 subscribers by tonight ill do a $20 paypal giveaway like this comment so your friends can see it or others and they can also be entered GO !!!	1
+865	i want to be that robot guy...	0
+866	Anyone else think this video theme is a bit of an insult to 28 days later? 	0
+867	subscribers please`	1
+868	OOOOO SHUFFLLLLLLLLLLLLLLLLLLLLLLINNNNN	0
+869	Party rock anthem is love,party rock anthem is life	0
+870	SERIOUSLY HOW DID THEY COME UP WITH THAT BEAT IT IS INSANELY GOOD HOLY FUCK	0
+871	Dance :)	0
+872	cool	0
+873	2015!! LLIKEE!!	0
+874	LMFAO - Party Rock Anthem ft. Lauren Bennett, GoonRock.	0
+875	Check out this video on YouTube:	1
+876	Check out this video on YouTube:	1
+877	Check out this video on YouTube:	1
+878	Please become my first subscriber.  Thank you.	1
+879	Laughing My Fucking Ass Off!!!	0
+880	i like this steps...	0
+881	give it a like	1
+882	Dear friends please subscribe to my channel I will be very glad to see You virucide ;-)	1
+883	Check out this video on YouTube:	1
+884	love this song	0
+885	Check out our app to solve all your party/drunk problems! <br /><a href="https://play.google.com/store/apps/details?id=vn.ibit.AppLocker&amp;hl=en">https://play.google.com/store/apps/details?id=vn.ibit.AppLocker&amp;hl=en</a>	1
+886	PLEASE DON&#39;T LIKE THIS COMMENT IF YOU ARE WATCHING IN 2015!!!!!!!!!!!	1
+887	The best song ever!	0
+888	The best Song i saw ❤️❤️❤️❤️❤️❤️❤️❤️😍😍😍😍😍😍😍😘😘😘😘😘😘😘😘	0
+889	Party rock	0
+890	I learned the shuffle because of them	0
+891	Tuto to subscribe to my channel because you should sign up for 17 l please thank you I&#39;d do anything for you to sign up a lot of good video I usually do!	1
+892	every bady yust have a good time	0
+893	Best for partying 	0
+894	Take a look at this video on YouTube:	1
+895	Love your song makes me happy	0
+896	why does the world not shuffle???	0
+897	Why did they stop their career in music? This music rocks !	0
+898	Why do I feel like as if Gangnam style copied their song from this?!	0
+899	5th most viewed video.. i guess	0
+900	Like this comment if you are watching on a phone	1
+901	At 500 subscribers i&#39;m sky diving help me reach my goal &lt;3<br />Trust me, i&#39;m a doctor.  :)	1
+902	Check out this video on YouTube:	1
+903	This is my favorite song ever love this Party Rock It Everybody!!!!	0
+904	Love this song makes me wanna dance! 	0
+905	I love this song so much	0
+906	THUMBS UP FOR ROBO GUY BABY	0
+907	Very Nice !	0
+908	EVERYONE PLEASE SUBSCRIBE TO MY CHANNEL OR CAN YOU ALL JUST GO LOOK AT MY VIDEOS 	1
+909	like this comment please	1
+910	still.. this reminds me of 1 years back when i was do shuffle everyday	0
+911	when i see this back in 2015 i ask myself how people got to like this song. seems like Gangnam style copied this style though, might just be me but yea	0
+912	Like this comment if you still jam out to this song after 4 years 	1
+913	LIKKEE	0
+914	Check out this video on YouTube:	1
+915	this video has 800+m views<br />and the channel got 3m subscribers	1
+916	hi guys please check out my vids , i will promos to subscribe to you to	1
+917	Check out this video on YouTube:	1
+918	Omg can this be the first video on YouTube to hit 1 billion views like this comment of you agree.	1
+919	Check out this video on YouTube:	1
+920	Party rock due and duel	0
+921	Check out this video on YouTube:	1
+922	subscribe me plzzzzzzz plzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz	1
+923	fucking love it omg :v	0
+924	Like this comment for no reason	1
+925	Need money ? check my channel and subscribe,soon will post how to get it )	1
+926	This song is just really fun 	0
+927	Check out this playlist on YouTube:	1
+928	I like so much this music,  good 	0
+929	Subscribe	1
+930	░░░░░░░/\░░░░░▄▐<br />░░░░░░/     \░░░▄██▄<br />░░░░░/  (o)   \░░░░░▀█▄<br />░░░░/             \░░░░░░▀█▄<br />░░░/__   \   ░▄▄▄▄▄▀▀<br />░░░░▄▄▄██▀▀▀▀<br />░░░█▀▄▄▄█░▀▀       YOU HAVE BEEN SPOOKED BY THE SPOOKINATI<br />░░░▌░▄▄▄▐▌▀▀▀<br />▄░▐░░░▄▄░█░▀▀    SHARE THIS TO 666 PEOPLE TO BE UNSPOOKED <br />▀█▌░░░▄░▀█▀░▀ <br />░░░░░░░▄▄▐▌▄▄<br />░░░░░░░▀███▀█░▄<br />░░░░░░▐▌▀▄▀▄▀▐▄<br />░░░░░░▐▀░░░░░░▐▌<br />░░░░░░█░░░░░░░░█<br />░░░░░▐▌░░░░░░░░░█<br />░░░░░█░░░░░░░░░░▐▌	1
+931	Check out this video on YouTube: I 	1
+932	No one makes me wanna party like LMFAO does... I just wanna rage every time one of these songs comes on. The only other band that does that is &quot;On the Rocks Inc.&quot; those kids know how to party	0
+933	you cant stop the shuffle	0
+934	Subscribe My Channel	1
+935	Abomination! Subscribe if you agree :| party here in my channel. cool lights.	1
+936	Like this comment for no reason	1
+937	Hey guys subscribe to my chanel and i will subscribe back and like all your vids :)<br /><br />	1
+938	Wow;)	0
+939	Check out this video on YouTube:	1
+940	2015<br />I like video	0
+941	Check out this video on YouTube:	1
+942	Very Nice !!!<br />Yeah Fucking.	0
+943	Ahhh back when my life didn&#39;t suck...	0
+944	Its funny that Mahogany is there lmao, I actually didn&#39;t know that her brother and uncle are part of LMFAO	0
+945	Check out my music niggas	1
+946	SUBSCRIBE MY CHANNEL PLEASE LOL PRO PLAYS)	1
+947	Yeah! Let&#39;s start the party!	0
+948	Can i get views and subscribers for no reason? 😅	1
+949	Check out this funny video &quot;Cereal Box Knocks out Baby&quot; on my channel.	1
+950	Check out this video on YouTube:	1
+951	Check out this video on YouTube:	1
+952	I love this song so much	0
+953	Best Music Ever!!!	0
+954	Super awesome video<br />	0
+955	Check out this video on YouTube:	1
+956	Check out this playlist on YouTube:	1
+957	This song is just insane.<br />Do you dance listening to this song?( i do, lol)	0
+958	Remeber when this song was good	0
+959	like if ur watchin in 90000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000                                                                                              1	1
+960	Way was their a guy warring a robot head.	0
+961	Dang Dat little kid gat moves	0
+962	Check out this video on YouTube:	1
+963	I came here because of Vanoss.	0
+964	I know it old song but (like) if you watching in 2015	0
+965	Check out this video on YouTube:opponents mm <br /><br /><br /><br />--•[••••=====++¥¥£££<br />	1
+966	Check out this video on YouTube:	1
+967	Check out this video on YouTube:	1
+968	This Song is AWESOME!!!!	0
+969	Check out this video on YouTube:	1
+970	HOW DO YOU PUT A PICTURE FOR YOUR IMAGE THINGIE?!?!	1
+971	subscribe to <span class="proflinkWrapper"><span class="proflinkPrefix">+</span><a class="proflink" href="https://plus.google.com/104999962146104962510" oid="104999962146104962510">SuperMarioLogan</a></span>  if you thot the robot dudes r epic and awsome	1
+972	Support the fight for your 4th amendment right to privacy in your home and business. Stop the NSA spying on Americans with the un Patriot Act Renewal. Rand Paul has spent 10.5 hours on the Senate floor in a Protest and Filibuster fighting for our  Constitution that this Nation is founded on. Join the fight at Rand Paul dot com. Spread The Word. We Have Someone That Cares About Our Nation.  Email your Senators, Congress men and women, tell them to support Rand. Tell the news to support Rand too. Senator Rand Paul was up until <a href="http://www.youtube.com/watch?v=KQ6zr6kCPj8&amp;t=1m00s">1:00</a> am this morning fighting for our Constitution.	1
+973	Check out this video on YouTube:	1
+974	like the songs	0
+975	Love this song makes me wanna dance! 	0
+976	The best song ever!	0
+977	I hate it when Laura Bennett comes in	0
+978	Check out this video on YouTube:	1
+979	Love this song makes me wanna dance! 	0
+980	Ummm... I just hit 1k subscribers. I make Minecraft videos. Help me out by checking me out?	1
+981	Memories	0
+982	WELL THIS DUDES FADED THE FUCK OUT!	0
+983	I remember when this used to be so popular all around at the time.<br />I still love it.	0
+984	sorry to all my haters<br /><br /><br /><br /><br /><br /><br /><br /><br /><br /><br /><br /><br />for party rock en	0
+985	Check out this video on YouTube:<br /><br /><br /><br />	1
+986	Check out this playlist on YouTube:<br /><br />	1
+987	White people are going extinct for more information subscribe to my channel or search for videos on &quot;white genocide&quot;  thank you	1
+988	EVERYONE PLEASE GO SUBSCRIBE TO MY CHANNEL OR JUST LOON AT MY VIDEOS	1
+989	Check out this video on YouTube:	1
+990	Very good! Like! :D	0
+991	😼👍😏 Like This Comment 😏👍😼	1
+992	Check out this video on YouTube:	1
+993	SUBSCRIBE me. if you do that leave your name so i can subs back 	1
+994	Check out this playlist on YouTube:a	1
+995	Check out Melbourne shuffle, everybody!	1
+996	Check out this video on YouTube:	1
+997	Check out my dance videos!! You won&#39;t be disappointed!! <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23KingLoTheDancer">#KingLoTheDancer</a>	1
+998	Looooooooove this song!!!!!!!!!:))))))	0
+999	Love these guys, love the song!	0
+1000	Check out this video on YouTu	1
+1001	LMFAO best songs ever!	0
+1002	Check out this video on YouTube:	1
+1003	Check out this video on YouTube:	1
+1004	love lmfao party rockin keep it going	0
+1005	Check out this video on YouTube:	1
+1006	Awsome<br />	0
+1007	cool cool cool cool cool cool cool	0
+1008	Remeber the good ol&#39; days when songs weren&#39;t about butts.  	0
+1009	hey guys im 17 years old remixer and producer and i want you guys to help me checking my videos i am sure you will love those music if you love them then subscribe so you will know more about me:)	1
+1010	Party time!	0
+1011	Man this song really does get in your bones - one minute I&#39;m nearly dosing off to sleep - the next I&#39;m shuffling like crazy with hot feet XD	0
+1012	This was such an epic track. And the video is awesome!	0
+1013	Check out this video on YouTube:	1
+1014	Check out this video on YouTube:	1
+1015	<a href="https://m.freemyapps.com/share/url/505b0232">https://m.freemyapps.com/share/url/505b0232</a>	1
+1016	Subscribe if you are watching in 2015	1
+1017	i was born in the wrong generation	0
+1018	Hey ! Subscribe  to me for the peace in the world ! ♥	1
+1019	Check out this video on YouTube:	1
+1020	Check out this video on YouTube:	1
+1021	pleas subscribe on me for ps4 games video <br /><i>______________________________</i><br />if you have som tips so contact me on kik or skype<br /><i>______________________________</i><br />Kik: pander26<br />Skype: sander.nicolaysen2<br /><i>______________________</i><br />pleas subscribe on me and Kashoo Gaming	1
+1022	SUBSCRIBE!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	1
+1023	that time in 2011 where this shirt was cool	0
+1024	music yeah	0
+1025	I&#39;m A SUBSCRIBER	1
+1026	LMFAO IS THE BEST	0
+1027	wowwwwwwwwwwwwwwwwwwwwwwwwwwwwww	0
+1028	Check out this playlist on YouTube:	1
+1029	I shuffled while listening to this song. THE ILLNESS IS SPREADING!!	0
+1030	HAHAA THIS DANCE IS TIGHTTTT<br /><br />I know y&#39;all &quot;…ain&#39;t got time for MY demo man&quot;  but check out some of my stuff<br /><br />Slappers on slappers on bangers! Click that link below to peep game! ENJOY. <br /><br /><a href="https://soundcloud.com/rocc-steady/wave-emoji-prod-by-nippylongbottom-cyber-punk">https://soundcloud.com/rocc-steady/wave-emoji-prod-by-nippylongbottom-cyber-punk</a>	1
+1031	Beautiful song	0
+1032	Check out this playlist on YouTube:	1
+1033	Check out this video on YouTube:	1
+1034	LMFAO!	0
+1035	Check out this video on YouTube:<br />Gotta dance and just have a blast every time I hear this song !!!! Just ❥love❥ it!<br />	1
+1036	I like it<br />	0
+1037	Best song ever!!!!	0
+1038	Like this in 2015! :D	0
+1039	Check out this playlist on YouTube:🍴🍴🏄🏄🏄🍴🏄🏄🏄🏄🏊🏊🏊🏊🍴🍴🍴🍴🍴🏂🏂🏂🏂🏂🏆🍸🍸🍸🍸🏆🍻🍗🍵🍟🍟🍟🍟🍴🍕🍕🍕🍕🍕🍕🍕🍕🍕🍕🍕🍕☕️🎣🎣☕️🍕🍕🎣🎣🎣🎣🎣🎣🎣🎣🎣🎣☕️🎣🍕🍔🍔🎣🎣☕️🎣🍹🍹🏂🏂🍹🎿🏆	1
+1040	 <br />Please help me get 100 subscribers by the end of the night. Thx	1
+1041	I just wanna see how many people like this comment. I&#39;ll give it a month or so :)	1
+1042	We can have a party next share	1
+1043	Check out this playlist on YouTube:	1
+1044	want a sub? tell me about your channel and i will subscribe (with a couple exceptions)	1
+1045	And i´m Shufflin still today :D	0
+1046	Thumbs up if FE-FE-FE-FE-FEGELEIN brought u here	1
+1047	Loves it	0
+1048	Likeeee	0
+1049	Party Rock....lol...who wants to shuffle!!!	0
+1050	<a href="http://www.gofundme.com/Helpmypitbull">http://www.gofundme.com/Helpmypitbull</a> Can you please donate to help my pitbull PLEASE!!! I just need 50 dollars to take it to the vet!!!	1
+1051	Check out this video on YouTube:	1
+1052	Like this comment for no reason.	1
+1053	Check out this video on YouTube: 	1
+1054	Ah. Good old times (:	0
+1055	wierd but funny	0
+1056	Thumbs up if you watched it in 2011	0
+1057	Help me get 10000000 subscribers by tomorrow!<br /><br /><br /><br /><br /><br /><br /><br /><br /><br /><br />(Joking don&#39;t get butt hurt)  	1
+1058	Subscribe to my channel !	1
+1059	Hello from Russia comes to the channel subscribe	1
+1060	NICE GIRL :D	0
+1061	Let get this video to one billion views	0
+1062	Check out this video on YouTube:lo	1
+1063	Its 2015 and still shuffling to this song🎶🎵🎧	0
+1064	Nice	0
+1065	BEST PARTY SONG LITERALLY PARTY ROCK IS IN THE HOUSEE TONIGHT!!!!	0
+1066	Check out this video on YouTube:	1
+1067	Lol check out my chanell and subscribe please i want 5000 subs thats it im nearly their now	1
+1068	Check out this video on YouTube:      <br />	1
+1069	like this comment if ur watching this on 2015 	1
+1070	Hey plz check out my music video. Thanks!! :-)	1
+1071	Its funny because I listen to rock and death metal. But i like this.	0
+1072	Yo like what up this song is fricking beast anywon herd the see mee rollin or I woke up in a new Buggti check out nuketown rap I was shooting in it see my nice quick scoping things training for faze adoult freind sang I was at my freinds house doing it we just had the rap going and we were doing fun things 	1
+1073	NICE :3	0
+1074	2015 &lt;3	0
+1075	Check out this video on YouTube:	1
+1076	I like this song<br />	0
+1077	Good video	0
+1078	Check out this video on YouTube:	1
+1079	OH SHIT THIS WAS UPLOADED ON MY BIRTHDAY, finally my birthday isnt a curse, because on this day maylaysia airlines went missing, so i thought my birthday was a curse xD	0
+1080	Do not like this comment if you are not watching in 2004	1
+1081	Check out this playlist on YouTube: I tried	1
+1082	Men stop being naive idiots believing in love and believing that women actually respect you.  Want a more fulfilling drama fee life? Check out MGTOW. There are many channels here on YouTube that you can chose from which will give you all the information needed. (Caution: Not for men who enjoy being disposable providers of utility)	1
+1083	LIKE AND SUBSCRIB IF YOU WATCH IN 2015 ;)	1
+1084	the views...... They&#39;re over 90,000!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	0
+1085	<a href="https://m.freemyapps.com/share/url/10b35481">https://m.freemyapps.com/share/url/10b35481</a>	1
+1086	EVERYONE PLEASE SUBSCRIBE TO MY CHANNEL OR CAN YOU ALL JUST GO LOOK AT MY VIDEOS 	1
+1087	Check out my YouTube channel I can rap	1
+1088	this very good so I dance with some companions front of the whole school	0
+1089	5 years... 	0
+1090	Check out this playlist on YouTube:	1
+1091	Check out this playlist on YouTube:	1
+1092	Damn, this was everywhere	0
+1093	IIIIIIIIIII LOVE THIS SHAKE IT SONG OH SORRY EVERY SHAKE IT SONG I LIKE WATCH SUBSCRIBE AND NEVER UNLIKE BROOOOO!!!!!!!!!!! SHAKE IT UP	1
+1094	Check out this playlist on YouTube:pl	1
+1095	Check out this video on YouTube	1
+1096	Best song ever	0
+1097	Subscribe to my channel <br />Tweet &amp; Follow me on twitter //therealterrell_ <br />And I will follow you back 👀	1
+1098	SUP GUS THIS IS A VIDEO FOR PEOPLE WHO LOVES PARTY ROCK SO THANKS FOR WATCHING AND PLEASE SUBSCRIBE:)!!	1
+1099	4 fucking years are fucking past so fucking fast fuck.....	0
+1100	Check out this video on YouTube:	1
+1101	I like how the robot shuffles he shuffles good	0
+1102	Party in da 🏠 tonight 👐👐👐👐👐👐	0
+1103	check it out free stuff for watching videos and filling surveys<br /><br /><a href="http://www.prizerebel.com/index.php?r=1446084">http://www.prizerebel.com/index.php?r=1446084</a>	1
+1104	Nice	0
+1105	this song is awesome. these guys are the best. love this video too its hilarious lol. im getting popular fast because i rap with meaning. thumbs up if you piss next to the water in the toilet so it doesnt make noise	1
+1106	CUTE  :)	0
+1107	Check out this video on YouTube:Qq	1
+1108	Hey guys, I&#39;m a human.<br /><br /><br />But I don&#39;t want to be a human, I want to be a sexy fucking giraffe.<br /><br /><br />I already have the money for the surgery to elongate my spinal core, the surgery to change my skin pigment, and everything else! Like this post so others can root me on in my dream!!!!<br /><br /><br />Im fucking with you, I make music, check out my first song! <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23giraffebruuh">#giraffebruuh</a>	1
+1109	Check out this video on YouTube:	1
+1110	Youtube comments in a nut shell:<br /><br />.First<br />.301 club<br />.Skip to <a href="http://www.youtube.com/watch?v=KQ6zr6kCPj8&amp;t=3m57s">3:57</a> and close your eyes<br />.Advertisements<br />.like this comment for no reason<br />.Christianity arguements<br />.Other religious arguements<br />.Console wars<br />.#PCMASTERACE<br />.Trolls<br />.&quot;How is there 1 million views on the this video if theres online 10 people on earth/<br />.Complaints<br />.Pewdiepie fangirls<br />.Minecraft scrubs<br />./r/MontageParodies<br /><br /><br />AND ANY OTHER SHIT!!	1
+1111	Cool	0
+1112	Check out this video on YouTube:	1
+1113	awesome	0
+1114	please suscribe i am bored of 5 subscribers try to get it to 20!	1
+1115	Check out this video on YouTube:	1
+1116	Check out this video on YouTube:	1
+1117	Hey I&#39;m a British youtuber!!<br />I upload Weekly!! <br />It would mean the world if you could subscribe!!<br />Thanks,Joyce!!<br />	1
+1118	Best song ever made i swear :D i still hear even doe it old!! who else?	0
+1119	Check out this video on YouTube:	1
+1120	LMFAO - Party Rock Anthem ft. Lauren Bennett, Goo…: <a href="https://youtu.be/KQ6zr6kCPj8">https://youtu.be/KQ6zr6kCPj8</a>     <br />BOX MAN GOT SOME MOVES! :D HE MAKES MEH SMILE	0
+1121	Omg it&#39;s going to have 1bi views!	0
+1122	Love this video and the song of course	0
+1123	Thumbs up if shrek is gay 👍	1
+1124	Best song for ever💜💜😢<br />	0
+1125	Never get old 	0
+1126	I like This Comment and do not kill :P	1
+1127	Check out this playlist on YouTube:m	1
+1128	Check out this video on YouTube:<br />	1
+1129	okay, this should cover me for some time... Thumbs up if you&#39;re watching while youtube still exists.	1
+1130	lets get it to 1 BILLION	0
+1131	tension⤴︎⤴︎	0
+1132	PARTY ROCK (8) ~	0
+1133	Check out this video on YouTube:	1
+1134	Strong messages in every  song I&#39;ve  heard.	0
+1135	Check out this video on YouTube:	1
+1136	great l subscribe	1
+1137	Check out this video on YouTube:	1
+1138	Check out this video on YouTube:	1
+1139	Wow justin Bieber is Better  thats why when he buys medication he always shares with his half wited money alfred but sadly enough he is an attention hog with swamp ass and an eating disorder filled with sassy mice, and flaming hot cheetos that he can eat with the power of the samurman.	1
+1140	Check out this video on YouTube:	1
+1141	Hey check out my channel!!!! Please	1
+1142	i want to smack this year boy in to forever	0
+1143	super rihanna	0
+1144	 subscribe to my feed	1
+1145	I love you!❤✨	0
+1146	Hi everyone. We are a duo and we are starting to record freestyles and put them on youtube. If any of you could check it out and like/comment it would mean so much to us because we love doing this. We may not have the best recording equipment but if you listen to our lyrics and rhymes I think you&#39;ll like it. If you do then please subscribe and share because we love making these videos and we want you to like them as much as possible so feel free to comment and give us pointers! Thank you!	1
+1147	i love song :)	0
+1148	SUBSCRIBE TO MY CHANNEL X PLEASE!. SPARE	1
+1149	check out my new EM cover video trailer	1
+1150	I lovet	0
+1151	Rihana, Love Me. :(	0
+1152	Check out Berzerk video on my channel ! :D	1
+1153	She looks like Megan Fox 😂 xD!!	0
+1154	GUYS SHARE THIS VIDEO!!!!!  HE must have 1.000.000.000 views  Share on Facebook, groups, pages!!	1
+1155	im M.E.S an aspiring young rapper with high hopes.  i know you will not click this link and check out my channel. you just dont want to.  well its your choice. in fact you know what DONT CLICK	1
+1156	Almost a billion	0
+1157	Great song	0
+1158	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at MONEYGQ.COM ! Visit MONEYGQ.COM and check it out! Memory Ferirama Besloor Shame Eggmode Wazzasoft Sasaroo Reiltas Moderock Plifal Shorogyt Value Scale Qerrassa Qiameth Mogotrevo Hoppler Parede Yboiveth Drirathiel	1
+1159	I shared my first song &quot;I Want You&quot;, and I hope you&#39;ll like it. Take a listen!✌️😊 <a href="https://youtu.be/-YfUY4gKr1c">https://youtu.be/-YfUY4gKr1c</a>	1
+1160	<a href="https://www.facebook.com/groups/100877300245414/">https://www.facebook.com/groups/100877300245414/</a>	1
+1161	This past Christmas my dad passed away, to help me cope I picked up a pen and some paper and started to write. Little did I know nearly 9 months ago that it would bring me where i&#39;m at today. Over 7k subscribers and 500k video views. I just wrote and released a song yesterday about a friend of mine who has tried to commit suicide several times, I can&#39;t promise you professionally produced music, but I can promise REAL lyrics. Just click on my picture. Thanks, and thumbs up so others may see?	1
+1162	EMINEM the best EVER.	0
+1163	LOVE TROP FORT VOTRE  clip	0
+1164	Feels and emotions in this song...God damn	0
+1165	2015 but Im still listening to this!	0
+1166	The perfect example of abuse from husbands and the thing is I&#39;m a feminist so I definitely agree with this song and well...if I see this someone&#39;s going to die! Just sayin.	0
+1167	Charlie from LOST	0
+1168	No-I hate The Way U LIe!!	0
+1169	Go check out eminem survival.	1
+1170	Check out this playlist on YouTube:	1
+1171	love the you lie the good	0
+1172	Every weekend a new lyric video of Eminem here VIEW LIKE SUBSCRIBE	1
+1173	Also check out D.j.j where do i go now and road to recovery	1
+1174	check out my Eminem &amp; Kid Cudi M a s h up /watch?v=XYTcq5NzMuA	1
+1175	Look and shares my video please :D	1
+1176	Hey! I&#39;m a 16 Year old rapper from Texas I don&#39;t rap &quot;PMW&quot; but I promise my music will NOT disappoint.. Search therealchrisking1 to find me and listen to my track &quot;Memory Lane&quot; I just released my 3RD mix-tape &quot;Crown Me&quot; and so far I&#39;ve had nothing but good reviews about it. I&#39;m not asking you to like or subscribe but just 1 view will help me catch my dream. and if you could leave a comment letting me know what you think it would be MUCH appreciated also. Thank you.	1
+1177	plese subscribe to me	1
+1178	check out my channel for rap and hip hop music	1
+1179	Eminem is idol for very people in España and Mexico or Latinoamerica	0
+1180	LOVE THIS SONG!!!	0
+1181	I love this-the talents of eminem and Skylar,works well together	0
+1182	SUBSCRIBE TO ME! I MAKE MUSIC!	1
+1183	RIHANNA - POUR IT UP (VINCENT T. REMIX) RIHANNA - POUR IT UP (VINCENT T. REMIX) RIHANNA - POUR IT UP (VINCENT T. REMIX) RIHANNA - POUR IT UP (VINCENT T. REMIX) RIHANNA - POUR IT UP (VINCENT T. REMIX) RIHANNA - POUR IT UP (VINCENT T. REMIX) RIHANNA - POUR IT UP (VINCENT T. REMIX) RIHANNA - POUR IT UP (VINCENT T. REMIX) RIHANNA - POUR IT UP (VINCENT T. REMIX)  CLICK! SUBSCRIBE!	1
+1184	share and like this page to win a hand signed Rihanna photo!!! fb -  Fans of Rihanna	1
+1185	Check out this playlist on YouTube:	1
+1186	Charlie, heroin will do that to you.	0
+1187	LOVE IT!!!!!!!	0
+1188	:D subscribe to me for daily vines	1
+1189	Check out this video on YouTube:	1
+1190	Wow this video almost has a billion views! Didn&#39;t know it was so popular 	0
+1191	Check out this video on YouTube<br /><br /><br />	1
+1192	Share Eminem&#39;s Artist of the Year video so he can win. We could see a performance and acceptance speech. Like this comment so everyone can see.  2014 =  Year of Eminem	1
+1193	<a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23Awesome">#Awesome</a> <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23Share">#Share</a> <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23RT">#RT</a> Eminem - Love The Way You Lie ft. Rihanna <a href="http://ow.ly/2zME8f">http://ow.ly/2zME8f</a>	1
+1194	Check out this video on YouTube:	1
+1195	You guys should check out this EXTRAORDINARY website called ZONEPA.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at ZONEPA.COM !   Visit Zonepa.com and check it out!  Why does the view disclose the macho lift? Why does the letter frame the thundering cause? Why does the talk prevent the conscious memory?	1
+1196	Check out this video on YouTube:	1
+1197	Awesome song!,congratulations!!!	0
+1198	Do you need more instagram followers or photo likes? Check out IGBlast.com and get em in minutes!	1
+1199	I guss this song is one of my worst fears in life, to be with someone who abusive towered me and live with him.... 	0
+1200	Subscribe To Mê Please Guys	1
+1201	Check out my channel for some lyricism.....	1
+1202	2008-2010 were the years for pop	0
+1203	Hey guys please just spent one minute for me:) im a singer from srilanka and im 18 year&#39;s old! SooO I LIKE TO BE GREAT SINGER will be a one day so i hope it would be happen in really my life. Just view my videos and rate it and let me know how is my voice please help us guys closer to living my dream:)! I really know you&#39;d helped me who read this comment he/she can just understand my feeling ! Thanks guys lv u really much more! SUBSCRIBE IT MY CHANNEL AND ONECE AGAIN THANK U !	1
+1204	Eminem and Rihanna sing the song very well.	0
+1205	I love your songs eminem your the rap god	0
+1206	could you guys please check out my channel for hiphop beats?	1
+1207	Eminem - Love the way you lie ♥ ♥ ♥	0
+1208	YO GUYS SORRY IF THIS ANNOYS YOU!!! BUT CHECK OUT MY CHANNEL, AND LATEST VIDEO. A LIKE AND SUBSCRIBE WOULD BE NICE TOO!!!!!!!	1
+1209	adam b beats check out my page 2013	1
+1210	check out my page ADAM B BEATS 2013	1
+1211	:D	0
+1212	First they were fighting... Then they were making out...	0
+1213	CHECK OUT MY COVER OF LOVE THE WAY YOU LIE PLEASE!!	1
+1214	Almost to a billion :)	0
+1215	Check out this video on YouTube:	1
+1216	I lover this song	0
+1217	best song ever (y)	0
+1218	This song is about Rape and Cheating     <br /><br /><br /><br /><br /><br /><br /><br /><br /><br />Basically.....	0
+1219	Hi loving it	0
+1220	eminem - RIHANNA	0
+1221	Omg! This guy sounds like an american professor green	0
+1222	Check out this video on YouTube:	1
+1223	Lol thats the guy from animal planet and lost. And the girl is megan fox i think hahah 	0
+1224	I agree they are just damn spammers. They suck. Im trying to do the same and get known. I know the feeling. I will help you out and leave likes and comments on your page. I hope you could do the same and maybe subscribe to me that would be awesome thanks	1
+1225	Hi I am from bangladesh 💜	0
+1226	I&#39;m subscribing to you just because your comment was that great.	1
+1227	eminem new song check out my videos	1
+1228	I like the Mmph version better	0
+1229	I hate rap and I like this song	0
+1230	I  love you Eminem	0
+1231	Hello to everyone! Please check out my video: /watch?v=2b4WyWpHi8c It takes just 2 minutes, hope I don&#39;t ask too much. If you can&#39;t help me just Like &amp; Share the video please, so more people will see it. Thumbs UP, it would mean so much to me, maybe will make my dream come true. Thanks to everyone, GOD BLESS YOU ALL!	1
+1232	Don&#39;t love someone soo much, love the way u lie..	0
+1233	Eminem rocks!	0
+1234	 Hey youtubers... {**}I really appreciate all of you who took the time, to read this, I am just a 19 year old boy who wants to be a successful musician in the           music world. {**}I dont have any money to advertise my channel, {**}If you could just visit my channel, comment on my video or subscribe, that would be great.... {**}It will only be few seconds of your life..... {**}Thank u to all the people who just gave me a chance l really appreciate it 	1
+1235	this song is better then monster by eminem	0
+1236	like please	1
+1237	You guys should check out this EXTRAORDINARY website called FIREPA.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at FIREPA.COM !   Visit FIREPA.COM and check it out!   Lake   . Busyglide . Sasaroo . Sore . Chillpal . Axiomatic . Naperone . Mere . Undesirable . Agreeable . Encouraging . Imperfect . Roasted . Insidious . Modgone . Quickest . Trelod . Keen . Fresh . Economic . Bocilile	1
+1238	all u should go check out j rants vi about eminem	1
+1239	that is megan fox	0
+1240	Was that Meghan fox??	0
+1241	Hey everyone, I&#39;m Dakoda Bigelow.  I&#39;m a 17 year old singer, rapper and producer.  I&#39;ve gained over 490,000 views so far but I really need more views, likes and subscribers in order to make it.  Check out my brand new music video for my song So Gone!  Also check out my new song called Deaf.  A few moments of your time would be greatly appreciated and you won&#39;t regret it :)  Please thumbs this up so more people can see it!  Thanks for the support.	1
+1242	Eminem is my insperasen and fav	0
+1243	beautiful song!	0
+1244	*for 90&#39;s rap fans*  check out my Big Pun - &#39;Beware&#39; cover!  Likes n comments very much appreciated!	1
+1245	Subscribe me Secret videos :D	1
+1246	❤❤❤❤❤❤❤	0
+1247	Who is still watching in 2015	0
+1248	Really good song .<br />you know love song song.	0
+1249	This song/video is such a trigger but it&#39;s just so good...	0
+1250	hey guys i know you wanna skip over this but please take a chance and check out my music I&#39;m a young up and coming rapper with a big dream and i appreciate constructive critisism. so please have a look thank you for your time	1
+1251	Still listening,still same pleasure	0
+1252	So freaking sad...	0
+1253	Hello. I only made ONE Eminem song on my channel. Could you guys just put A LIKE on it please? Don&#39;t need to subscribe. Thanks.	1
+1254	Media is Evil! Please see and share: W W W. THE FARRELL REPORT. NET  Top Ex UK Police Intelligence Analyst turned Whistleblower Tony Farrell exposes a horrific monstrous cover-up perpetrated by criminals operating crimes from inside Mainstream Entertainment and Media Law firms. Beware protect your children!! These devils brutally target innocent people. These are the real criminals linked to London&#39;s 7/7 attacks 2005.  MUST SEE AND MAKE VIRAL!!! Also see UK Column video on 31st January 2013.	1
+1255	is that megan fox x:D?	0
+1256	do you guys know, there&#39;s a part two of this song! :D	0
+1257	Hey Go To My Channel Check Out My Dongs Thanks YouTuber&#39;s	1
+1258	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!  When does the flimsy slip facilitate the manager? How does the noise set goals the anxious regret? How does the actually loss retrieve the smile?	1
+1259	Check out this video on YouTube:	1
+1260	Like &amp; Subscribe /watch?v=5tu9gN1l310	1
+1261	Is that girl is Megan fox 	0
+1262	amazing song	0
+1263	Sick Music for sick females	0
+1264	You guys should check out this EXTRAORDINARY website called ZONEPA.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at ZONEPA.COM ! Visit Zonepa.com and check it out! How does the war illustrate the exclusive space? The mountain refers the helpless death. The tax reviews the special music.	1
+1265	+447935454150 lovely girl talk to me xxx	1
+1266	We need to get this to 1 Billion Views!!	0
+1267	Fuck you Eminem	0
+1268	Charlie got off the island and dated Megan Fox? b-but what about claire?	0
+1269	Check Out The New Hot Video By Dante B Called Riled Up	1
+1270	wtf. subscribe my channel thanx ;)	1
+1271	Please check out my New Song (MUSIC VIDEO) AD - Dont Play	1
+1272	awesome	0
+1273	Go check out my rapping video called Four Wheels please ❤️	1
+1274	this is the 4th most watched video on youtube. and hes the 21 most subscribed channel on youtube, just recentley surpassing justing bieber and is less then 1,00 subs away from beating taylor swift	1
+1275	Please check out my New Song (Music Video) AD - Don&#39;t Play	1
+1276	I could hear this for years ;3	0
+1277	YO GUYS IM 14 YEAR OLD RAPPER JUST STARTED RAPPING  SO PLEASE CHECK OUT MY SITE AND LEAVE FEEDBACK AND SUBSCRIBE  ALSO LIKE THIS COMMENT SO MORE CAN SEE AND MAKE MY CHANNEL BIG ASWELL	1
+1278	 I can&#39;t believe this I just now watched a brand new adult video clip with rihanna She has been screwing a black Basketball player  Check out her video right here if you wish:   crazy-celeb-news.eu.pn	1
+1279	No long paragraph just check out my song called &quot;Fire&quot;.	1
+1280	Hello Brazil 😻✌💓😻👏	0
+1281	Holy crap. 800,000,000 views?!	0
+1282	charlieee :DDDD (Those who saw Lost only will understand)	0
+1283	Looooved 	0
+1284	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!  The cook officiates the tax. The walk judges the amount. Why does the ink train the valuable increase?	1
+1285	check out fantasy music    right here -------&gt; the1fantasy  good music man.	1
+1286	hey guy if you can please SUBSCRIBE to my channel im a young dedicated rapper i post videos everyday to improve  i write/perform/record/mix/edit/post everyday  i do a verse everyday (16 bars)to improve i&#39;m doing it for 365 days a whole year to improve right now i&#39;m on day 46  if you guys can please like this comment so everyone can see it and follow me on my journey/watch me improve everyday SUBSCRIBE PLEASE I&#39;m lyrical and i keep it real  help me reach my dream,help me build a fan base,THANKS(:	1
+1287	Check Out LEANDRUS - PLAYTIME It&#39;s awesoooome !!	1
+1288	if eminem gets 1 penny per view he would have 600 million dollars	1
+1289	Check out this playlist on YouTube:chcfcvzfzfbvzdr	1
+1290	:)	0
+1291	Tell us the title so i can like and subscribe your music fgw please	1
+1292	Come check out our music channel! go check out the song &quot;Love Me&quot; by J.E.M.INI! If you do you are a very amazing person, thank you to all who check it out, if you don&#39;t check it out you&#39;re still an amazing person just for reading this post	1
+1293	Alright ladies, if you like this song, then check out John Rage.  He&#39;s a smoking hot rapper coming into the game.  He&#39;s not better than Eminem lyrically, but he&#39;s hotter. Hear some of his songs on my channel.	1
+1294	Check out this video on YouTube:<br /><br />Love this song... It&#39;s all good will never go back that but I&#39;ll always remember the passion but never want to go back to being dysfunctional insanity....... Goal is to live happy not live insane. 	1
+1295	Check out this video on YouTube:<br /><br />Eminem is amazing. 	1
+1296	I wish that guy wasn&#39;t so protective geeze	0
+1297	You guys should check out this EXTRAORDINARY website called ZONEPA.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at ZONEPA.COM !   Visit Zonepa.com and check it out!  Why does the answer rehabilitate the blushing limit? The push depreciateds the steel. How does the beautiful selection edit the range?	1
+1298	Made five years ago and people still don&#39;t understand the message this song is conveying. For those of you who don&#39;t, understand that domestic violence is no game, women like the fictional one in this video do exist. Living their lives with a man who abuses her, but still chooses to stay for not the love, but for the sexual feeling he gives her. This is quite sad, actually. 	0
+1299	▬▬▬▬▬▬▬▬▬▬ஜ۩۞۩ஜ▬▬▬▬▬▬▬▬ CHECK OUT MY CHANNEL ▬▬▬▬▬▬▬▬▬▬ஜ۩۞۩ஜ▬▬▬▬▬▬▬▬	1
+1300	Usually guys I would vote people like this down but actually check out his channel, it reminds me of a smaller version of 8 mile. Give this guy a chance since everyone has to start somewhere!	1
+1301	Not bad	0
+1302	hay my is honesty wright i am 12year old  i love your song thank you for makeing this song i love this song so much sometime harts can get breaken  people kill  they self or go cazzy i  love you so much thanks 😱👏keep on go  make  sure you doing your   dream is comeing rule  good luck	0
+1303	CHECK OUT THESE LYRICS /watch?v=yUTTX04oyqQ	1
+1304	everyone come and check out the new GTA 5 Gameplay right Here : /watch?v=6_H0m5sAYho	1
+1305	yo I know nobody will probably even read this..  But I’m gonna type it any way because i hope at least one person will  i&#39;m a rapper with a dream. I know there&#39;s like 200k of those in this world  but please check out my music and subscribe? if you want.? thanks,  I would love nothing more than to have a decent following on youtube..  if anyone? reading this could press the &quot;THUMBS UP&quot; other people will see it  just a simple button press? could make my dream come true =) Thank You	1
+1306	Love Song	0
+1307	Every collaboration between them, we know it will be number 1 	0
+1308	Fuck Eminem. Bieber is the best &lt;3	0
+1309	Check out this playlist on YouTube:	1
+1310	Rihanna is absolutely gorgeous in this video.	0
+1311	who the fuck cheats on megan fox	0
+1312	hello friends. i am a young 15 year old rapper trying to make something out of nothing..Please can you take a second of your life and check out my videos and help me reach my Dreams! GoD Bless YOU	1
+1313	e.e....everyone could check out my channel.. dundundunnn	1
+1314	Hey everyone check out my channel leave a like and subscribe please and if there is a song you want me to do post the name in the comments and I will get on to it(: Thanks	1
+1315	I love eminem  &lt;3	0
+1316	Check out this video on YouTube:	1
+1317	2010? The time past so fast ..	0
+1318	What the hell this song is already five years old?? I remember when it first came out, this was my jam	0
+1319	That guy charley of lost TV show	0
+1320	Check out this video on YouTube:	1
+1321	1,000,000 VIEWS NEAR	0
+1322	Love this song	0
+1323	WE GO FOR 1,000,000,000 FOR EMINEM	0
+1324	I love this song, can&#39;t believe it was 5 years ago, it doesn&#39;t get old though	0
+1325	What nicei⛺♥♥♥♥	0
+1326	My friends wife earns 4000DOLLARS a month ,you can do it do if you want to be a wwhore ,you can not get these type of wages for working from home on a PC taking surveys ...dont believe that shit,just because yahoo and other are taking their money does not mean they are LEGIT ..Please like this so the MSG gets thru to vulnerable  people Like eminem used to be 	1
+1327	/watch?v=aImbWbfQbzg watch and subscrible	1
+1328	Hi guys ! !  Check Out My Remixes ! !  Thanx You&#39;re Great ! ! SWAG ! !	1
+1329	Megan Fox is gorg in this!! Eminem is truly the rap god :)	0
+1330	Anybody else here in 2015?	0
+1331	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!  The metal drews the reflective effect. Why does the expansion intervene the hilarious bit? The sneeze witnesss the smoke.	1
+1332	is it bad that my realtionship is just like this lol	0
+1333	Check Out The New Hot Video By Dante B Called Riled Up	1
+1334	EMINEM FANS!!!  - Check Out The New Song &quot;FEELIN&#39; GOOD&quot; By J-D-P  Just Click The &quot;GHOSTPOET100&quot; Link Above This Post  Thank You All...	1
+1335	subscribed :) btw you have a good style keep it up brother :))	1
+1336	Check out my videos guy! :) Hope you guys had a good laugh :D	1
+1337	super	0
+1338	Getting too 1billion views, holy moly.!!!	0
+1339	i love Rihanna 😍😍😍😍[♧from Thailand♧]	0
+1340	check out you tube keithlinscotts one word keithlinscotts you tube .com	1
+1341	Hello I&#39;am from Palastine	1
+1342	SnEakiESTG Good Music. Hood Muzik Subscribe 2 My Channel. Thanks For The Support. SnEakiESTG   SnEakiESTG Good Music. Hood Muzik Subscribe 2 My Channel. Thanks For The Support. SnEakiESTG	1
+1343	Me and my big sister like you	0
+1344	You guys should check out this EXTRAORDINARY website called ZONEPA.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at ZONEPA.COM ! Visit Zonepa.com and check it out! How does the mammoth waste achieve the shock? How does the limit reduce the delicate minute? How does the meaty scale adapt the oil?	1
+1345	Who is watching in 2015 like	0
+1346	watch?v=ARkglzjQuP0 Like this comment and share this video so Em can win!!! #YTMA	1
+1347	This song is like an oreo, the black part is good but the white part is better	0
+1348	Hey, it&#39;s Charlie from Lost	0
+1349	I love music	0
+1350	yo I know nobody will probably even read this..  but, Imma type it any way because i hope at least one person will  i&#39;m a rapper with a dream. I know there&#39;s like 200k of those in this world  but please check out my music and subscribe? if you want.? thanks,  I would love nothing more than to have a decent following on youtube..  if anyone? reading this could press the &quot;THUMBS UP&quot; other people will see it  just a simple button press? could make my dream come true =) Thank You	1
+1351	watch youtube video &quot;EMINEM -YTMA artist of the year&quot; plz share to vote!!!	1
+1352	share and like this page to win a hand signed Rihanna photo!!! fb -  Fans of Rihanna	1
+1353	OMG that looks just like a piece of the mirror of harry potter and the deathly hallows.<br />Either that house. (sirius black)	0
+1354	i hate rap	0
+1355	Thumbs up if you listen this in 2015.	0
+1356	Check out this video on YouTube: <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23Eminem">#Eminem</a> <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23Lovethewayyoulie">#Lovethewayyoulie</a> <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23RapGod">#RapGod</a> <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23King">#King</a> 	1
+1357	Check out this video on YouTube:	1
+1358	EVERYONE GO AND SHARE youtu  be/ARkglzjQuP0 ON FB,TWITTER,G+ TO VOTE FOR EMINEM TO BECOME ARTIST OF THE YEAR ON FIRST EVER YOUTUBE MUSIC AWARDS !!!  AND GET THIS METHOD TO CHEAT AT INTERNET ROULETTE OUT OF EMINEMS VIDEO ! SHADY ARTIST OF THE YEAR !	1
+1359	2010:(	0
+1360	like this comment then type 1337	1
+1361	Best. Song. EVER 🙌	0
+1362	COME SUBSCRIBE TO MY CHANNEL! ;-)  PLEASE!!	1
+1363	Check out this playlist on YouTube:	1
+1364	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!  Why does the innocent woman prioritize the desire? The flight searchs the sad polish. When does the tax zip the omniscient record?	1
+1365	LADIES!!! -----&gt;&gt; If you have a broken heart or you just want to understand guys better you should check out this underground book called The Power of the Pussy on AMAZON. Best book ever for us girls! Oh...and just a warning it&#39;s for 18 and over...lol	1
+1366	Check Out The New Hot Video By Dante B Called Riled Up	1
+1367	Hi Guys im an Upcoming Rapper if you could check out my channel and tell me what you think maybe subscribe or like i would really appreciate it all HATERS are welcome to :) thanks	1
+1368	so many comments.	0
+1369	subscribe to my channel who can	1
+1370	Charlie from Lost!	0
+1371	Eminem best rapper all the time	0
+1372	Okay trust me I&#39;m doing a favor. You NEED to check out this guy named Columbus Nelson on YouTube.	1
+1373	sorry but eminmem is a worthless wife beating bastard	0
+1374	You guys should check out this EXTRAORDINARY website called ZONEPA.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at ZONEPA.COM ! Visit Zonepa.com and check it out! How does the burst render the symptomatic bite? The knowledge briefs the narrow thought. How does the eager sky transmit the crush?	1
+1375	Every single one of his songs brings me back to place I can never go back to and it hurts so bad inside	0
+1376	  Eminem is the king of rap  Micheal Jackson is the king of pop  If you also wanna go hard and wanna be the person of first class fame just check out Authenticviews*com and be famous just within days !! yO ~	1
+1377	 HI IM 14 YEAR RAPPER SUPPORT ME GUY AND CHECK OUT MY CHANNEL AND CHECK OUT MY SONG YOU MIGHT LIKE IT ALSO FOLLOW ME IN TWITTER @McAshim for follow back.	1
+1378	CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere CHECK OUT THIS DOPE CHANNEL!    phenomenallyricshere	1
+1379	I love this song sooooooooooooooo much	0
+1380	Almost 1 billion	0
+1381	is that megan fox?	0
+1382	Check Out The New Hot Video By Dante B Called Riled Up	1
+1383	Love you	0
+1384	Check out my channel to see Rihanna short mix by me :)	1
+1385	Lemme Top Comments Please!!	0
+1386	Eminem is the greatest artist to ever touch the mic.	0
+1387	Dress like Rihanna at kpopcity.net - The largest discount fashion store in the world! Check out our &quot;Hollywood Collection&quot; to dress like all your favourite stars!   Dress like Rihanna at kpopcity.net - The largest discount fashion store in the world! Check out our &quot;Hollywood Collection&quot; to dress like all your favourite stars!	1
+1388	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!   Lake   . Ignorant . Wavefire . Reiltas . Astauand . Skizzle . Jovaphile . Swooflia . Grynn . Excellent . Slimy . Gabby . Nalpure . Lucky . Glozzom . Depressed . Better . Deep . Sinpad . Stereotyped . Toximble	1
+1389	Never gets old best song ever  ❤	0
+1390	CHECK OUT MY YOUTUBE VIDEOS FOR FUNNY AND COOL RAP	1
+1391	I love this song up to the moon &gt;3 you are Rocks!	0
+1392	tryna work with some rappers check out the ones i already have on my channel	1
+1393	awesome song ever	0
+1394	CHECK OUT MY MUSIC VIDEO ON MY CHANEL!!!	1
+1395	i love this song	0
+1396	You guys should check out this EXTRAORDINARY website called ZONEPA.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at ZONEPA.COM ! Visit Zonepa.com and check it out! The loud authority improves the canvas. When does the mother repair the uppity learning? The substantial cook derives the harbor.	1
+1397	So he&#39;s admitting he killed his girlfriend???	0
+1398	Help me get 50 subs please 	1
+1399	You exactly who u want to be,watching your favourite rappers on tv	0
+1400	Hello I&#39;m from Bulgaria	0
+1401	fav.	0
+1402	Check Out The New Hot Video By Dante B Called Riled Up	1
+1403	hahahahah ♥♥♥♥ :D like vines ?  Subscribe to me for daily vines	1
+1404	This video is kinda close to 1 million  views <br />	0
+1405	You guys should check out this EXTRAORDINARY website called ZONEPA.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at ZONEPA.COM !   Visit Zonepa.com and check it out!  The jelly activates the reflective distribution. The normal top synthesizes the opinion. The victorious plant entertains the language.	1
+1406	Eminem et Rihana trop belle chanson	0
+1407	 Subscribe and like my video please	1
+1408	Is that Charlie from lost?	0
+1409	Love	0
+1410	subscribe me if u love eminem	1
+1411	I love this song	0
+1412	Is that Megan Fox?	0
+1413	  Check out my SEXY VIDEO :*	1
+1414	love the way you lie featuring rhianna, hes an awesome rapper!!! shes an awesome singer!!!	0
+1415	Hay dakota u earned a subscribee	1
+1416	Check out this video on YouTube:	1
+1417	▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌ FACEBOOK PASSWORD HACK 2013! facebook-pass-hack2013.blogspot.com ONLY 100% WORKING SOFTWARE FOR HACKING ANY FACEBOOK PASSWORD! It&#39;s FREE for download and breaks any password in 10-15 minutes! 100% virus free! For details and FREE DOWNLOAD please visit facebook-pass-hack2013.blogspot.com ▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌▌	1
+1418	CHECK OUT THIS NEW VIDEO I MADE CALLED &quot;WE LOVE MIND MASTER IT&quot;, THANK U SO MUCH 	1
+1419	if you need youtube subscriber mail hermann buchmair on fb	1
+1420	This guy win dollars sleeping... m m m he loves the planet its full of RETARDS	1
+1421	some classsic :))))	0
+1422	I like this song very much	0
+1423	Rihanna looks so beautiful with red hair ;)	0
+1424	I personally have never been in a abusive relationship. I probably never will. I don&#39;t hit women. Mom has my dad used to hit my mom before he left. I can relate I&#39;m writing about one at the moment subscribe to hear it. EVERY fan counts.	1
+1425	hey its M.E.S here I&#39;m a young up and coming rapper and i wanna get my music heard i know spam wont get me fame. but at the moment i got no way of getting a little attention so please do me a favour and check out my channel and drop a sub if you enjoy yourself. im just getting started so i really appreciate those who take time to leave constructive criticism i already got 200 subscribers and 4000 views on my first vid ive been told i have potential	1
+1426	this song is NICE	0
+1427	Aye homies check out our remix to 50 Cent Your Life Is On The Line we just started our youtube channel and we are all ways working hard, give us some feed back on our latest song on what you guys think if you like show support.	1
+1428	eminem is a ginius stop!	0
+1429	This video deserves <b>1B</b> views!!!	0
+1430	song is bad	0
+1431	He gets more views but has less subscribers lol	1
+1432	hey guys if you guys can please SUBSCRIBE to my channel ,i&#39;m a young rapper really dedicated i post a video everyday ,i post a verse (16 bars)(part of a song)everyday to improve i&#39;m doing this for 365 days ,right now i&#39;m on day 41  i&#39;m doing it for a whole year without missing one day if you guys can please SUBSCRIBE and follow me on my journey to my dream watch me improve, it really means a lot to me  thank you (:, i won&#39;t let you down i promise(: i&#39;m lyrical i keep it real!	1
+1433	check out eminem latest track survival if u didnt	1
+1434	CHECK OUT MY CHANNEL BOYS AND GIRLS ;)	1
+1435	Check out our cover of this song!	1
+1436	thumbs up if you think this should have 1 billion views	0
+1437	she is megan fox?because she is very similar	0
+1438	amazing song	0
+1439	i been working so hard for the past 60 days to improve i been writing/recording/mixing/performing 16 bars everyday im doing it for a whole year(365 days) today i&#39;m on day 60 today i work so hard now i&#39;m doing 2 verses everyday to improve i&#39;m going all the way to 365 PLEASE SUBSCRIBE to my channel  help me build a fan base ,watch me improve,help me get closer to my dream i&#39;m lyrical and i keep it real i won&#39;t leave you down i promise(: PLEASE SUBSCRIBE please like this  commment so people can see	1
+1440	dude check out psy	1
+1441	This great Warning will happen soon. ,0 LneaDw26bFst76VHKJL8PxaEy6VMNlvmriUDTSFK6vY,Ali Paša,2013-09-26T22:28:17.047000,Croatia &lt;3,0 LneaDw26bFvkAHxpKEnM25FYWkyXthsUpri6JuQsZnU,G Belrus,2013-09-26T22:26:12.832000,Nice one,0 LneaDw26bFtvZQt6JUEhasIEFRJG1exI_dVqdnQVPho,exode. comeback.,2013-09-26T22:23:00.710000,600m views.,0 LneaDw26bFunOarAg71AwGU6TJO6aZDKFIUn_TZ1_HY,Muhammad Shaeel Abbas,2013-09-26T22:15:45.476000,Fuck off!,0 LneaDw26bFt-oToUFj0z3vffLFNaxyKwZSIVQhiMx-E,Notorious Niko,2013-09-26T22:00:43.613000,"Hey guys im a 17yr old rapper trying to get exposure... I live in belgium where NO ONE speaks english so i have to resort to this gay SPAM...  Check out my 2 latest tracks as they are probably my best.. Audio isnt the best but im gonna invest in some real equipment for my next track..  Please Thumbs this up so others can see.. or hey dont just check me out yourself and leave a response and a like :D  Thanks in advance, you guys will be part of making my dream come TRUE   -Notorious Niko 	1
+1442	hey its M.E.S here I&#39;m a young up and coming rapper and i wanna get my music heard i know spam wont get me fame. but at the moment i got no way of getting a little attention so please do me a favour and check out my channel and drop a sub if you enjoy yourself. im just getting started so i really appreciate those who take time to leave constructive criticism i already got 200 subscribers and 4000 views on my first vid ive been told i have potential	1
+1443	hey guys look im aware im spamming and it pisses people off but please take a moment to check out my music.  im a young rapper and i love to do it and i just wanna share my music with more people  just click my picture and then see if you like my stuff	1
+1444	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at MONEYGQ.COM ! Visit MONEYGQ.COM and check it out! Wazzasoft Industry Sertave Wind Tendency Order Humor Unelind Operation Feandra Chorenn Oleald Claster Nation Industry Roll Fuffapster Competition Ociramma Quality	1
+1445	Eminem THE BEST !	0
+1446	Im gonna share a little ryhme canibus blows eminem away a quadrillion times especially about the categories of intelligent things in his mind. That he learned and rapped about and forgot before eminem spit his first ryme.luv ya linz 4e	1
+1447	Terrance. .thank you for serving our country. How do i &quot;like you&quot; or &quot;subscribe&quot;?	1
+1448	❤️❤️❤️	0
+1449	goot	0
+1450	Rihanna is so beautiful and amazing ♥♥♥♥♥love her so much♥♥♥♥ forever RiRi fan ♥ ♥ ♥ ♥ ♥ ♥ 	0
+1451	Aslamu Lykum... From Pakistan	1
+1452	Is that Charlie from lost?<br />	0
+1453	Subscribe to me for clean Eminem!	1
+1454	SubScribe me pls EMİNEM FANS	1
+1455	WATCH MY VIDEOS AND SUBSCRIBE	1
+1456	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!   Lake   . Magnificent . Noodile . Unequaled . Moderock . Gogopo . Lulerain . Olielle . Zesty . Laughable . Accidental . Pepelexa . Delightful . Wiry . Toogit . Uncovered . Chesture . Woozy . Adhoc . Weak . Shallow	1
+1457	Finally someone shares the same opinion as me. I&#39;ve tried really hard to like Gangnam style but I just can&#39;t, I don&#39;t even get the dance either. A song like this should be more popular than Gangnam Style, at least it makes sense.	1
+1458	Maybe no one will probably read this. But just in case you do Can You Put A &quot;Thumbs Up &quot;So Others Can See. I Just started rapping seriously (Type in Lunden- 1990 Video) just a simple button? can help my dreams come true. people will ignore this but if you don&#39;t can you listen and subscribe Thank you all God Bless.	1
+1459	check out our bands page on youtube killtheclockhd - check out some of our original songs including &quot;your disguise&quot;	1
+1460	This song is true because it is insane because boys will do this to a girl and it is not true that people say that a guy will never do it to a girl but boys YOU LIERS NOW STOP TREATING US GIRLS THIS WAY YOU ALL SUCK!	0
+1461	thumb up if you watching in 2015 and  you like it	0
+1462	i like the lyrics but not to music video	0
+1463	LOVE THE WAY YOU LIE ..&quot;	0
+1464	BR	0
+1465	Check out my channel please.	1
+1466	Love the video 	0
+1467	#1 song in world even in 2015	0
+1468	Hey if you guys wouldnt mind...could you check out my boys and my music...we just made a real lyrical song...search DNA Andrew Guasch...I appreciate it. Leave some real feedback and keep Hip-Hop alive	1
+1469	I KNOW YOU MAY NOT WANT TO READ THIS BUT please do  I&#39;m 87 Cypher an 11 year old rapper I have skill people said .my stuff isn&#39;t as good as my new stuff but its good please check out my current songs comment and like thank you for reading rap is my life	1
+1470	Love😘❤💖	0
+1471	this fucking song like a&#39;n oreo the only white part is the good 1	0
+1472	Check out Em&#39;s dope new song monster here: /watch?v=w6gkM-XNY2M  MMLP2 FTW :)	1
+1473	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!   Lake   . Victorious . Luxuriant . Alcoholic . Responsible . Unbiased . Yoffa . Ociramma . Ociramma . Handsome . Arrowgance . Mizuxe . Boaconic . Sophisticated . Ill-fated . Spourmo . Chubby . Hioffpo . Probable . Singlewave	1
+1474	Take a look at this video on YouTube:	1
+1475	if u love rihanna subscribe me	1
+1476	good music	0
+1477	Love your songs<br />Supper cool<br />	0
+1478	Check Out The New Hot Video By Dante B Called Riled Up	1
+1479	Hey? Everyone Please take a moment to read this. I work my ass off hoping to make it into the music industry but its hard to? get people to hear you when you don’t have money for advertisements and exposure I am a Young Artist who has dreams and goals like everybody else Please take 30 seconds and visit my channel You don’t have to like? me, just give me a chance to prove my talent Also, please take just 1 second of your life and Thumb this comment up it would be Super Helpful.	1
+1480	Best song	0
+1481	HUH HYUCK HYUCK IM SPECIAL WHO&#39;S WATCHING THIS IN 2015 IM FROM AUSTRALIA OR SOMETHING GIVE ME ATTENTION PLEASE IM JUST A RAPPER WITH A DREAM IM GONNA SHARE THIS ON GOOGLE PLUS BECAUSE IM SO COOL.	1
+1482	it is wonderful	0
+1483	2015 and more....	0
+1484	The boyfriend was Charlie from the TV show LOST 	0
+1485	no where near one of eminems actual best songs, real fans know what im talking about: Untitled, cold wind blows, welcome 2 hell, elevator, business, wtp, almost famous, 25 to life, rock bottom, no apologies, same song and dance, without me, way i am, toy soldiers, mosh and insane &lt; songs off the top of my head that&#39;s better	0
+1486	Check out this video on YouTube:	1
+1487	I know that maybe no one will read this but PLEASE TYPE IN &quot;deazy99&quot; I&#39;m a rapper with a dream. I know you must see like millions of those on here everyday but please check out my music and subscribe if you&#39;d like thank you, i would love nothing more than to have a decent following on youtube from people if anyone reading this could give it a &quot;THUMBS UP&quot; because what some might see as just a simple button press could make my dream come true..thank you again for your time &amp; may god bless you	1
+1488	Charlie from LOST?	0
+1489	Check out this video on YouTube:	1
+1490	┏━━━┓┏┓╋┏┓┏━━━┓┏━━━┓┏┓╋╋┏┓  ┃┏━┓┃┃┃╋┃┃┃┏━┓┃┗┓┏┓┃┃┗┓┏┛┃  ┃┗━━┓┃┗━┛┃┃┃╋┃┃╋┃┃┃┃┗┓┗┛┏  ┗━━┓┃┃┏━┓┃┃┗━┛┃╋┃┃┃┃╋┗┓┏┛  ┃┗━┛┃┃┃╋┃┃┃┏━┓┃┏┛┗┛┃╋╋┃┃  ┗━━━┛┗┛╋┗┛┗┛╋┗┛┗━━━┛╋╋┗┛ CHECK MY VIDEOS AND SUBSCRIBE AND LIKE PLZZ	1
+1491	hey guys i know its annoying getting spammed sorry bout that but please take a moment to check out my channel  IM A RAPPER with DECENT skills. i want to share my music with more people so take a listen   THUMBS UP SO MORE CAN SEE THIS	1
+1492	MEGAN FOX AND EMINEM TOGETHER IN A VIDEO  DOESNT GET BETTER THAN THAT	0
+1493	Check out this video on YouTube: <a href="http://www.youtube.com/user/jlimvuth">http://www.youtube.com/user/jlimvuth</a> ... Eminem ft Rihanna - Love the way you lie	1
+1494	He is good boy!!!<br />I am krean I like to eminem~!~	0
+1495	How is this the most watched Eminem video, it beats not afraid? 	0
+1496	Cool	0
+1497	Hi I&#39;m lil m !!! Check out love the way you lie!!!! My live performance and many others,,, videos and my own lyrics!!!!! Thanks	1
+1498	Check out this video on YouTube:	1
+1499	Eminem rap can be easy for a boy but I can do that	0
+1500	Fruits and vegetables give you longer lasting energy for weight loss.  Check out youtube.com/user/36loseweight.	1
+1501	this is the 2nd most best song when im gone by m&amp;m	0
+1502	I always end up coming back to this song<br />	0
+1503	Rihanna and Eminem together are unstoppable.	0
+1504	Hello. İ am from Azerbaijan<br />	0
+1505	You can not hate eminem and nirvana...trust me	0
+1506	Hey guys I&#39;m 87 cypher im 11 years old and Rap is my life I recently made my second album desire ep . please take a moment to check out my album on YouTube thank you very much for reading every like comment and subscription counts	1
+1507	Come and check out my music!Im spamming on loads of videos but its just to try and get my music across to new people	1
+1508	I like it	0
+1509	I love the way you lie	0
+1510	••••►►My name is George and let me tell u EMINEM is my idol my inspiration, I&#39;ve listen to him growing up, I never thou I would love rap this much But I thank him for helping me find my dream. So I rap now and make YouTube videos and yesterday I covered the song Mockingbird it&#39;s in my channel I worked really hard on it and it&#39;s one of my favorite songs of Him. So please go check it out and subscribe it would mean the world to me n sorry for the spam ◄◄••­•• don&#39;t hate I&#39;m not Eminem	1
+1511	CHECK OUT THE NEW REMIX !!!<br />CLICK CLICK !!	1
+1512	You guys should check out this EXTRAORDINARY website called FIREPA.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at FIREPA.COM !   Visit FIREPA.COM and check it out!   Lake   . Chillpal . Sturdy . Astauand . Johackle . Chorenn . Ethosien . Changeable . Public . Noxu . Ploosnar . Looplab . Hoppler . Delicious . False . Scitenoa . Locobot . Heartbreaking . Thirsty . Reminiscent	1
+1513	Love the way you lie II is nicer in my opinion. :D	0
+1514	I know that maybe no one will read this but PLEASE TYPE IN &quot;deazy99&quot; I&#39;m a rapper with a dream. I know you must see like millions of those on here everyday but please check out my music and subscribe if you&#39;d like thank you, i would love nothing more than to have a decent following on youtube from people if anyone reading this could give it a &quot;THUMBS UP&quot; because what some might see as just a simple button press could make my dream come true..thank you again for your time &amp; may god bless you	1
+1515	Check out my remix to Tyga&#39;s - Molly ft. Whiz Khalifa Peace and love	1
+1516	Hey guys im a 17yr old rapper trying to get exposure... I live in belgium where NO ONE speaks english so i have to resort to this gay SPAM...  Check out my 2 latest tracks as they are probably my best.. Audio isnt the best but im gonna invest in some real equipment for my next track..  Please Thumbs this up so others can see.. or hey dont just check me out yourself and leave a response and a like :D  Thanks in advance, you guys will be part of making my dream come TRUE   -Notorious Niko 	1
+1517	check out my playlist	1
+1518	Check out this video on YouTube:	1
+1519	My favorite song 💗💗💗💗	0
+1520	857.482.940 views AWESOME !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!	0
+1521	please read this please! i am a country singer who is trying to become someone. so if you would like my videos please. just type... boys round here by country girl (only on computer) wasting all these tears e.t. please subscribe and like	1
+1522	best rap ever	0
+1523	5 years and i still dont get the music video help someone?	0
+1524	Simply rap god	0
+1525	EMINEM&lt;3 <br />the best rapper ever&lt;3	0
+1526	love	0
+1527	I know that maybe no one will read this but PLEASE TYPE IN &quot;deazy99&quot; I&#39;m a rapper with a dream. I know you must see like millions of those on here everyday but please check out my music and subscribe if you&#39;d like thank you, i would love nothing more than to have a decent following on youtube from people if anyone reading this could give it a &quot;THUMBS UP&quot; because what some might see as just a simple button press could make my dream come true..thank you again for your time &amp; may god bless you	1
+1528	ayyy can u guys please check out my rap video im 16 n im juss tryna get some love please chrck it out an thank u	1
+1529	Hey guys ready for more 87 Cyphers back check out my video on YouTube. NEW ALBUM IS OUT CHECK IT OUT. MORE MUSIC TOMORROW THANKS FOR READING	1
+1530	Amazing	0
+1531	Check out this video on YouTube:	1
+1532	adam b beats check out my page	1
+1533	DO YOU KNOW HOW SEAN KINGSTON GOT FAMOUS WHY DON&#39;T YOU LOOK IT UP KID BEFORE YOUR SO HARD ON YOURSELF!!  IF YOU HIT ME UP WITH A MESSAGE IN MY INBOX AND SUBSCRIBE I WILL CHECK OUT YOUR CHANNEL....SOUNDS FAIR TO ME.	1
+1534	still listening in 2015	0
+1535	so spousal abusue cool that&#39;s great	0
+1536	like this comment then type 1337	1
+1537	COFFEE ! LOVERS ! PLEASE ! READ ! Check out a song I wrote and sing on You Tube called COFFEE LOVA.Type COFFEE LOVA like I spell it while your already on You Tube hit enter.Then look for video titled COFFEE LOVA hit enter and BLAST ! OFF !	1
+1538	For all you ladies out there......  Check out this link!  You&#39;ll find the hottest hairstyles and the latest trends for women!  Go to this site and you&#39;ll upgrade your hairstyles and fashion senses to a higher level!  Don&#39;t get left behind!     ---&gt;   goo.gl\BxrOSR	1
+1539	youtube.com/watch?v=2ASFn9ShgHk&amp;feature=youtu.be  please check out my song. looking for feedback, and supporters.	1
+1540	I cried this song bringing back some hard memories	0
+1541	Fire..	0
+1542	hey its M.E.S here I&#39;m a young up and coming rapper and i wanna get my music heard i know spam wont get me fame. but at the moment i got no way of getting a little attention so please do me a favour and check out my channel and drop a sub if you enjoy yourself. im just getting started so i really appreciate those who take time to leave constructive criticism i already got 200 subscribers and 4000 views on my first vid ive been told i have potential	1
+1543	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!  Why does the wood photograph the husky breath? When does the act retain the delightful system? The rhythm fabricates the scintillating harbor.	1
+1544	do you want to make some easy money? check out my page tvcmcadavid.weebly . com dont miss out on this opportunity	1
+1545	hey its M.E.S here I&#39;m a young up and coming rapper and i wanna get my music heard i know spam wont get me fame. but at the moment i got no way of getting a little attention so please do me a favour and check out my channel and drop a sub if you enjoy yourself. im just getting started so i really appreciate those who take time to leave constructive criticism i already got 200 subscribers and 4000 views on my first vid ive been told i have potential	1
+1546	Just gonna stand there and hear me cry ..	0
+1547	If you are a person that loves real music you should listen to &quot;Cruz Supat&quot;<br />He is awesome as fuck!!! Just as eminem used to be.	0
+1548	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!   Lake   . Waratel . Misty . Exciting . Swoquix . Acaer . Chillpal . Tupacase . Arrowgance . Lively . Hard . Idiotic . Bored . Cool . Ablaze . Crabby . Aloidia . Cheilith . Feandra . Useless . Ploosnar	1
+1549	I love it and my mom to	0
+1550	sorry for the spam yall I know it’s annoying. But if you can spare a min please check out the new track on my channel i&#39;m a upcoming uk rapper.please come check out my songs u might like em. If not no worries I’m sorry for wastin your time. Even thumbs up to get more noticed will really help. peace yall	1
+1551	Check out this video on YouTube:	1
+1552	3 yrs ago I had a health scare but thankfully I’m okay. I realized I wasn’t living life to the fullest.  Now I’m on a mission to do EVERYTHING I’ve always wanted to do. If you found out you were going to die tomorrow would you be happy with what you’ve accomplished or would you regret not doing certain things? Sorry for spamming I’m just trying to motivate people to do the things they’ve always wanted to. If you’re bored come see what I’ve done so far! Almost 1000 subscribers and I just started!	1
+1553	Check out this playlist on YouTube: 	1
+1554	*****PLEASE READ*****  Hey everyone! I&#39;m a 19 year old student  who loves to sing.  I record and upload covers on my channel. I would love if you could check them out.  Just give me a chance. You don&#39;t have to thumbs up or subscribe (But you can if you like what your hear)  Just listen. I&#39;d really appreciate.  ~THANK YOU for your time. ~	1
+1555	my sister just received over 6,500 new <a rel="nofollow" class="ot-hashtag" href="https://plus.google.com/s/%23active">#active</a> youtube views Right now. The only thing she used was pimpmyviews. com	1
+1556	Check out this video on YouTube:	1
+1557	if u love rihanna subscribe me	1
+1558	Check out my channel for funny skits! Thanks!	1
+1559	HEY GUYS!!! ❤❤❤❤❤❤❤  BEFORE YOU IGNORE ME, PLEASE, GIVE ME A CHANCE!  My name is Yuliya, I make COVERS. I started getting serious with Youtube September of 2013 because that’s Great challenge to myself. MY DREAM was to become a singer. Youtube helps me keep up! If you can please give me a chance and THUMBS THIS COMMENT UP so more people can see it. I swear I&#39;ll appreciate it. SUBSCRIBE PLEASE!!!  ❤❤❤ LOVE YOU!!! ❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤ XOXO ❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤❤ 	1
+1560	subscribe to my channel  /watch?v=NxK32i0HkDs	1
+1561	Like eminen	0
+1562	Share this video.. This song can beat PSY - Gangnam Style!	1
+1563	PLEASE CHECK OUT MY VIDEO CALLED &quot;WE LOVE MIND MASTER IT&quot; THANKS	1
+1564	I don&#39;t understand this song, I have had the passion for women feel like im losing my mind but i don&#39;t understand the ideas of loving the the way someone lies.....	0
+1565	I like the music...but is anyone listening to the lyrics?	0
+1566	  Check out my SEXY VIDEO :*	1
+1567	Check out my channel im 15 year old rapper!	1
+1568	one of the BEST SONGS in music history	0
+1569	CHECK OUT Eminem - Rap God LYRIC VIDEO	1
+1570	Haha, I can&#39;t believe how many kids listen to Eminem.  You kids know that he does drugs and ends high in jail. I wonder why he never ended OD. 	0
+1571	Love the way you lie - Driveshaft	0
+1572	Anyone else notice that Megan Fox is in this video?	0
+1573	#2015 FUCK YEAH	0
+1574	is that Megan fox	0
+1575	subcribe to us an we will subscribe back	1
+1576	Check out this video on YouTube:but I&#39;m not Rhinnah	1
+1577	Listen...Check out Andrew Guasch - Crazy, Sick Flow...I&#39;m dope....that&#39;s all there is too it. If you like it Subscribe, if not, Ill be with Aftermath of TDE Soon enough.  One Love,  Peace.	1
+1578	best song	0
+1579	hot,hot	0
+1580	Check out our Channel for nice Beats!!	1
+1581	Check out my mummy chanel!	1
+1582	The rap: cool     Rihanna: STTUUPID	0
+1583	I hope everyone is in good spirits I&#39;m a hard working student who&#39;s also a passionate singer I look foward to the day when I can make my own music to share But for now I&#39;ve just been doing covers. Check out my channel, I&#39;ve done Covers of Miley Cyrus, Imagine Dragons, Lana Del Rey, Drake, Macklemore, Pink and countless others.  Subscribe only if you want to. My goal isn&#39;t to become famous but to  inspire FYI this isn&#39;t spamming, everyone has a right to freedom of speech. Thanks 	1
+1584	Lil m !!!!! Check hi out!!!!! Does live the way you lie and many more ! Check it out!!! And subscribe	1
+1585	Please check out my youtube channel! Just uploaded my first youtube video please check it out, you will not regret it!	1
diff --git a/data/youtube/valid.tsv b/data/youtube/valid.tsv
new file mode 100644
index 0000000000000000000000000000000000000000..3530f659d828714885b9f1ccaaed0b7045f84970
--- /dev/null
+++ b/data/youtube/valid.tsv
@@ -0,0 +1,121 @@
+index	sentence1	label
+0	860,000,000 lets make it first female to reach one billion!! Share it and replay it! 	0
+1	Waka waka eh eh	0
+2	You guys should check out this EXTRAORDINARY website called ZONEPA.COM . You can make money online and start working from home today as I am! I am making over $3,000+ per month at ZONEPA.COM ! Visit Zonepa.com and check it out! How does the mother approve the axiomatic insurance? The fear appoints the roll. When does the space prepare the historical shame?	1
+3	Check out  these Irish guys cover  of Avicii&#39;s  Wake Me Up!  Just search...  &quot;wake me up Fiddle Me Silly&quot; Worth a listen  for the gorgeous fiddle player!	1
+4	if you want to win money at hopme click here <a href="https://www.paidverts.com/ref/sihaam01">https://www.paidverts.com/ref/sihaam01</a> it&#39;s work 100/100	1
+5	Love it	0
+6	Hey Youtubers and All Music lover&#39;s, Guess most of you all skip these comments, but for you who is still reading this, thanks ! I dont have any money for advertisiments, no chance of getting heard, nothing. All that&#39;s left is spam, sorry. Im 17, Rapper/Singer from Estonia. Please listen my new cover on my account. You wont regret it. Give me just a chance, please. Take half a second of your life and thumb this comment up. It will maybe change my life, for real. Thank you Wafence 	1
+7	**CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE****	1
+8	/watch?v=Dtqcftr1Fac JUSTIEN BIEBER CAR 2013. LIKE&amp;SUBSCRIBE	1
+9	WOW muslims are really egoistic..... 23% of the World population and not in this video or donating 1 dollar to the poor ones in Africa :( shame on those terrorist muslims	1
+10	Recommend:  Apple iPad 4th Gen 32GB Unlocked Wi-Fi+4G 9.7in White Price:$390  Apple iPhone 5 (Latest Model) - 32GB - Black Price:$385  Samsung Galaxy S4 S IV 4 with 16GB New White Price:$360  Sony 60-inch 3D LED HDTV Price:$510  All-in-One PCs: Apple MacBook Pro: Apple MacBook Air Price:$320  Camera :Nikon D90 SLR Camera /18-55mm /55-200mm 32GB  Price:$390   Ultrabooks: SONY VAIO Pro 13 Intel Core i5 4GB 128GB Price:$515  +++++++++++++++    Purchase online Website is:  Taaee.com	1
+11	Whose watching this in 2015. If so hi-5	0
+12	Love this song !!!!!!	0
+13	Her voice sounds weird and plus she&#39;s cute for a blonde	0
+14	One of the best song of all the time	0
+15	Echa un vistazo a la remezcla! / Check out the remix!  MILEY CYRUS - WRECKING BALL (THE HOUSE OF EDM REMIX)  MILEY CYRUS - WRECKING BALL (VINCENT T. REMIX)  ...click, enlace, suscríbase! / ...click, link, subscribe!	1
+16	Cutie girl and beautiful song	0
+17	Check Out Daneja Good Girl	1
+18	You guys should check out this EXTRAORDINARY website called MONEYGQ.COM .   You can make money online and start working from home today as I am!   I am making over $3,000+ per month at MONEYGQ.COM !   Visit MONEYGQ.COM and check it out!  Why does the fragile swim enlist the person? How does the ice audit the frequent son? The fantastic chance describes the rate.	1
+19	i love this song thumsb up to you	0
+20	Hey Music Fans I really appreciate all of you who take time to read this, and check my music out! I&#39;m just a 15 year old boy DREAMING of being a successful MUSICIAN in the music world. I do lots of covers, and piano covers. But I dont have money to advertise. A simple thumbs up to my comment, a comment on my videos or a SUBSCRIPTION would be a step forward! It will only be a few seconds of your life that u won&#39;t regret!!! Thank u to all the people who just give me a chance it means a lot! :)	1
+21	Hi.Check out and share our songs.	1
+22	SUBSCRIBE ME AND I REQUITE	1
+23	Shakira u are so wiredo	0
+24	Fuck it was the best ever 0687119038 nummber of patrik kluivert his son share !	1
+25	Check out this video on YouTube:	1
+26	Shakira	0
+27	Hello everyone :) I know most of you probably pass up these kind of comments, but for those who are still reading this, thanks! I don’t have any money for advertisements, no chance of getting heard, nothing. I live in such a small town... If this comes off as spam, sorry. I’m an instrumental songwriter from Columbus, Mississippi. Please go to my channel and check out my original music. It would be highly appreciated if you thumbs up this comment so my music can be heard! Thank you, Adam Whitney 	1
+28	check out our bands page on youtube killtheclockhd - check out some of our original songs including &quot;your disguise&quot;	1
+29	Hey guys whats up? I found this app that lets you get free gift card vouchers like psn cards,X-box live cards and even amazon gift cards. For free! All you have to do is  simply just download the app from the app store. It is called juno wallet. All you have to do is just sign up for the app and then complete a few surveys or just download some other free apps and you get money like 10 cents. Also, if you type in the code IM2458444. You will also start off with $0.25 free!! 	1
+30	She is perfect	0
+31	Hi. Check out and share our songs.	1
+32	wanna earn money online without investment.....just visit this link .....therglove.blogspot.in/2013/08/blog-post_10.html	1
+33	Me and my aunt love this song!!!!!	0
+34	I love this song and expect the World Cup .	0
+35	Stop Wasting Up Your Time and  Get Paid To Mess Around On Facebook And Twitter!  GET PAID UPTO $25 to $35 AN HOUR... Only at 4NetJobs.com  Work from the Comfort of your Home... We are Currently Hiring People from all Over the World,  For a Wide Range of Social Media Jobs on Sites such as Facebook,Twitter and YouTube.  You don&#39;t Need any Prior Skills or Experience and You can Begin Work Immediately!  You Can Easily Make $4000 to $5000+ Monthly Income…Only at 4NetJobs.com	1
+36	the song is sad	0
+37	SUBSCRIBE MY CHANNEL	1
+38	1 753 682 421 GANGNAM STYLE ^^	1
+39	Best world cup offical song	0
+40	I remember this :D	0
+41	**CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE******CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE**** **CHECK OUT MY NEW MIXTAPE*** ***CHECK OUT MY NEW MIXTAPE****	1
+42	Love this song!!!	0
+43	Hi, nice song Shakira! (Sorry for bad Brazilian)	0
+44	Part 5. Comforter of the afflicted, pray for us Help of Christians, pray for us Queen of Angels, pray for us Queen of Patriarchs, pray for us Queen of Prophets, pray for us Queen of Apostles, pray for us Queen of Martyrs, pray for us Queen of Confessors, pray for us Queen of Virgins, pray for us Queen of all Saints, pray for us Queen conceived without original sin, pray for us Queen of the most holy Rosary, pray for us Queen of the family, pray for us Queen of peace, pray for us 	1
+45	Nice song	0
+46	Hey Music Fans I really appreciate any of you who will take the time to read this, and check my music out! I&#39;m just a 15 year old boy DREAMING of being a successful MUSICIAN in the music world. I do lots of covers, and piano covers. But I don&#39;t have money to advertise. A simple thumbs up to my comment, a comment on my videos or a SUBSCRIPTION would be a step forward! It will only be a few seconds of your life that you won&#39;t regret!!! Thank u to all the people who just give me a chance! :) 	1
+47	like!!!!!	0
+48	subscribe to my pagee please.	1
+49	shakira is the best!	0
+50	:D subscribe to me for daily vines	1
+51	I WILL NEVER FORGET THIS SONG IN MY LIFE LIKE THIS COMMENT OF YOUR HEARING THIS SONG FOR LIKE A YEAR!!!!!	1
+52	I felt old when I realized that this song was 5 years old...	0
+53	GREAT!!!	0
+54	I Know Y&#39;all Can Check Out Amy Music Because Yall Got Me Go To Da Channel	1
+55	Shakira is different :) She is so happy all the time and she is spending for Africa :)<br />She can dance , sing and she speaks 4 languages	0
+56	I really am madly in love with this woman!!	0
+57	I love you  ;p	0
+58	Hi. Check out and share our songs.	1
+59	Where are Shakifans?	0
+60	Hey Music Fans I really appreciate any of you who will take the time to read this, and check my music out! I&#39;m just a 15 year old boy DREAMING of being a successful MUSICIAN in the music world. I do lots of covers, and piano covers. But I don&#39;t have money to advertise. A simple thumbs up to my comment, a comment on my videos or a SUBSCRIPTION would be a step forward! It will only be a few seconds of your life that you won&#39;t regret!!! Thank u to all the people who just give me a chance! :) 	1
+61	Check out this playlist on YouTube<br />	1
+62	Wow	0
+63	She is good	0
+64	Being paid to respond to fast paid surveys from home has enabled me to give up working and make more than 4500 bucks monthly.  To read more go to this web site bit.ly\1bSefQe	1
+65	I love	0
+66	Check out my covers please!	1
+67	So underrated better<br />Than  Katy perry :/ but in not saying Katy is bad but she has no meaning to her songs. Shakira has<br />Meaningful songs these are the songs which bring memories and vibe&#39;s. I will miss it.   <br />   	0
+68	Every time I hear this song, I think about Iniesta&#39;s goal against the Netherlands...	0
+69	Meet The Richest Online Marketer  NOW CLICK : bit.ly/make-money-without-adroid	1
+70	i like it :)	0
+71	Check out my channel :)	1
+72	Like	0
+73	Cool song 	0
+74	Shakira is perfect	0
+75	we are 14 year old boys we are trying to enter in the music industry Maybe we are just a dreamers but we are not the? only one. Please give us one chance to prove ourself to you. we know these things are annoying as hell we are sorry? about that. But we do anything to get music heard. Please visit our channel subscribe if you like and thumb this comment up so everyone can see we have made a song called: I Wanna Play A Game please take a look on our channel Thankss!!!	1
+76	BEST SONG EVER X3333333333	0
+77	********OMG Facebook is OLD! Check out  -----------------&gt; swagFriends com Make thousands of cool new friends everyday! Join the movement!	1
+78	:)	0
+79	  Perhaps you have seen the newest Miley Cyrus SECRET video ?   She&#39;s sucking an old man&#39;s cock ,  If you wish to see her , check out the celebrity website beneath :   miley-secret-video.co.uk 	1
+80	goood	0
+81	Check out this video on YouTube:	1
+82	Hey guys love this but check out this girl name cause she knows how to dance so search up ej ba dancing.	1
+83	waka waka	0
+84	Message :   GTA V  $20  FIFA 14 $15  PS4  $200  Galaxy S4 mini $250  Ipad 4   $200  visit the site hh.nl	1
+85	i love her	0
+86	Pleas subscribe GamezZ MTA my channel<br />	1
+87	I love this song &amp;I love shakira&lt;333♡♡♡♡♡♡♡♡♡♡♡《33☆☆☆	0
+88	she is sooooo beautiful!	0
+89	Wanderfol is love or miusic	0
+90	visit &quot; ww estiloproduction com &quot; best website to make money	1
+91	wow	0
+92	Waka waka 	0
+93	Amazing song	0
+94	I love dis song!! 3	0
+95	coby this USL and past :<br /><a href="http://adf.ly">http://adf.ly</a> /1HmVtX<br />delete space after y	1
+96	This makes me miss the world cup	0
+97	ILove shakira 	0
+98	adf.ly / KlD3Y	1
+99	I believe that soccer promotes terrorism. Bad choice, Shakira.	0
+100	i watched this because of the large amount of views and now i am convinced it is because that girl is so hot	0
+101	Still amazing this song!! I did a French Kpop Parody. Please come on my Channel to watch it. Share, suscribe, comment Thanks	1
+102	Nice love itttttttt wurkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkkk	0
+103	Please go in and see our Channel and subscribe :-). It would be Nice mate.	1
+104	BEST SONG! GO SHAKI :D	0
+105	god she is so sexy! drives me crazy!	0
+106	5 years ago damn 	0
+107	CHECK OUT DANEJA GOOD GIRL	1
+108	Earn money for being online with 0 efforts!    bit.ly\14gKvDo	1
+109	Love this song	0
+110	PLEASE CHECK OUT MY VIDEO CALLED &quot;WE LOVE MIND MASTER IT&quot; THANKS	1
+111	To help shakira become the first female to hit 1billiom views, I&#39;ve decided to watch this Video at least one a day. Everyday! Shakifans we r so close just hit the replay button !!!	0
+112	There are beautiful songs please subscribe	1
+113	I love this song for two reasons: 1.it is about Africa 2.i was born in beautiful south Africa	0
+114	http://www.ebay.com/usr/shoecollector314	1
+115	It makes me happy instantly, and makes me forgot everything bad happening!	0
+116	I love song 	0
+117	5 years soon!	0
+118	Lol I love this song	0
+119	Thumbs up if your watching in 2015	1
diff --git a/dst/agra_dst.py b/dst/agra_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c69a8e2bada252333c885afdbef1a1f2cbd90a
--- /dev/null
+++ b/dst/agra_dst.py
@@ -0,0 +1,99 @@
+# coding=utf-8
+#
+# Copyright 2024 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of AGRA
+# (arXiv:2306.04502)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import autograd_hacks
+import os
+import torch
+import sys
+
+import numpy as np
+
+from torch.utils.data import (DataLoader, WeightedRandomSampler)
+
+from agra import (AGRA)
+from utils_gpu import (to_device)
+from dst.utils_storm_dst import (batch_to_dict)
+
+
+class AGRA(AGRA):
+    # Implementation is limited to the slot gates.
+    def _get_comp_grads(self):
+        comp_grads = {params[0]: params[1].grad.reshape(-1).detach().clone().cpu()
+                      for params in self.model.named_parameters()
+                      if hasattr(params[1], 'grad1') and 'bias' not in params[0]}
+        if 'comp_grads' not in self.stats:
+            self.stats['comp_grads'] = {}
+        mean_comp_grads = {}
+        for slot in self.model.slot_list:
+            if slot not in self.stats['comp_grads']:
+                self.stats['comp_grads'][slot] = []
+            self.stats['comp_grads'][slot].append(comp_grads['class_' + slot + '.weight'].numpy())
+            self.stats['comp_grads'][slot] = self.stats['comp_grads'][slot][-1 * self.window_size:]
+            mean_comp_grads[slot] = np.mean(self.stats['comp_grads'][slot], 0)
+        return mean_comp_grads
+
+
+    def build_dataloader(self, batch_size):
+        comp_sampler = WeightedRandomSampler(self.agra_weights, len(self.dataset))
+        self.comp_dataloader = DataLoader(
+            self.dataset, sampler=comp_sampler, batch_size=batch_size, drop_last=True)
+        self.comp_dataloader = iter(self.comp_dataloader)
+
+
+    # Implementation is limited to the slot gates.
+    def agra_step(self, batch):
+        batch_size = batch['input_ids'].size(0)
+
+        # Get comparison gradients
+        comp_batch = next(self.comp_dataloader)
+        autograd_hacks.clear_grad1(self.model)
+        for slot in self.model.slot_list:
+            getattr(self.model, "class_" + slot).weight.retain_grad() # required for _get_comp_grads()
+        comp_batch = to_device(batch_to_dict(comp_batch), self.device)
+        comp_outputs = self.model(**comp_batch, suppress_dropout_passes=True)
+        comp_loss = comp_outputs[0][0]
+        comp_loss.backward()
+        autograd_hacks.compute_grad1(self.model, loss_type="sum", layer_groups=["class_"])
+        comp_grads = self._get_comp_grads()
+
+        for slot in self.model.slot_list:
+            del getattr(self.model, "class_" + slot).weight.grad
+        autograd_hacks.clear_grad1(self.model)
+
+        # Get sample gradients
+        outputs = self.model(**batch)
+        sample_loss = outputs[0][0]
+        sample_loss.backward()
+        autograd_hacks.compute_grad1(self.model, loss_type="sum", layer_groups=["class_"])
+        grads = {params[0]: params[1].grad1.detach().clone().cpu() for params in self.model.named_parameters() if
+                 hasattr(params[1], 'grad1') and 'bias' not in params[0]}
+
+        # Get gradient scores
+        grad_scores = {}
+        for slot in self.model.slot_list:
+            grad_scores[slot] = np.zeros(batch_size)
+            for l_itr in range(batch_size):
+                sample_grads = grads["class_" + slot + ".weight"][l_itr].reshape(-1).numpy()
+                grad_scores[slot][l_itr] = np.sum(sample_grads * comp_grads[slot]) / ((np.linalg.norm(sample_grads) * np.linalg.norm(comp_grads[slot])) + 1e-8)
+            grad_scores[slot] = torch.tensor(grad_scores[slot])
+
+        autograd_hacks.clear_grad1(self.model)
+
+        return grad_scores
+    
diff --git a/dst/modeling_dst.py b/dst/modeling_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5952e84d32337a9579cc8863f834208722430da
--- /dev/null
+++ b/dst/modeling_dst.py
@@ -0,0 +1,282 @@
+# coding=utf-8
+#
+# Copyright 2024 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of TripPy
+# (arXiv:2005.02877)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers import (BertModel, BertPreTrainedModel,
+                          RobertaModel, RobertaPreTrainedModel,
+                          ElectraModel, ElectraPreTrainedModel)
+
+from modeling import (PARENT_CLASSES,
+                      MODEL_CLASSES,
+                      StraightThroughEstimator,
+                      ElectraPooler)
+
+
+def TransformerForDST(parent_name):
+    if parent_name not in PARENT_CLASSES:
+        raise ValueError("Unknown model %s" % (parent_name))
+
+    class TransformerForDST(PARENT_CLASSES[parent_name]):
+        def __init__(self, config):
+            assert config.model_type in PARENT_CLASSES
+            assert self.__class__.__bases__[0] in MODEL_CLASSES
+            super(TransformerForDST, self).__init__(config)
+            self.model_type = config.model_type
+            self.slot_list = config.dst_slot_list
+            self.class_types = config.dst_class_types
+            self.class_labels = config.dst_class_labels
+            self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
+            self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
+            self.class_aux_feats_inform = config.dst_class_aux_feats_inform
+            self.class_aux_feats_ds = config.dst_class_aux_feats_ds
+            self.class_loss_ratio = config.dst_class_loss_ratio
+            self.dropout_rounds = config.dropout_rounds
+            self.no_class_separation = config.no_class_separation
+            self.rescaler_featnum = config.rescaler_featnum
+            self.rescaler_binary = config.rescaler_binary
+            self.rescaler_binary_threshold = config.rescaler_binary_threshold
+            self.config = config
+
+            # Only use refer loss if refer class is present in dataset.
+            if 'refer' in self.class_types:
+                self.refer_index = self.class_types.index('refer')
+            else:
+                self.refer_index = -1
+
+            self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config))
+            if self.model_type == "electra":
+                self.pooler = ElectraPooler(config)
+            
+            self.dropout = nn.Dropout(config.dst_dropout_rate)
+            self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
+
+            if self.class_aux_feats_inform:
+                self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
+            if self.class_aux_feats_ds:
+                self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
+
+            aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
+
+            for slot in self.slot_list:
+                self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
+                self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
+                self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
+
+            if not self.no_class_separation:
+                for c in self.class_types:
+                    self.add_module("rescaler_" + c, nn.Sequential(
+                        nn.Linear(self.rescaler_featnum, self.rescaler_featnum),
+                        nn.BatchNorm1d(self.rescaler_featnum, affine=False),
+                        nn.ReLU(),
+                        nn.Linear(self.rescaler_featnum, 2),
+                        nn.BatchNorm1d(2, affine=False),
+                        nn.Softmax(dim=1),
+                    ))
+            else:
+                self.rescaler = nn.Sequential(
+                    nn.Linear(self.rescaler_featnum, self.rescaler_featnum),
+                    nn.BatchNorm1d(self.rescaler_featnum, affine=False),
+                    nn.ReLU(),
+                    nn.Linear(self.rescaler_featnum, 2),
+                    nn.BatchNorm1d(2, affine=False),
+                    nn.Softmax(dim=1),
+                )
+            if self.rescaler_binary:
+                if not self.no_class_separation:
+                    for c in self.class_types:
+                        getattr(self, "rescaler_" + c).add_module(str(len(getattr(self, "rescaler_" + c)) + 1), StraightThroughEstimator(config.rescaler_binary_threshold))
+                else:
+                    self.rescaler.add_module(str(len(self.rescaler) + 1), StraightThroughEstimator(config.rescaler_binary_threshold))
+
+            # Initialize weights and apply final processing
+            self.init_weights()
+
+        def forward(self,
+                    input_ids = None,
+                    input_mask = None,
+                    segment_ids = None,
+                    position_ids = None,
+                    head_mask = None,
+                    start_pos = None,
+                    end_pos = None,
+                    inform_slot_id = None,
+                    refer_id = None,
+                    class_label_id = None,
+                    ids = None,
+                    diag_state = None,
+                    suppress_dropout_passes = False,
+                    feats = None,
+                    mode = "default"):
+
+            # --------------
+            # Rescaler model
+            # --------------
+
+            if mode == "rescaler":
+                if not self.no_class_separation:
+                    logits = tuple([getattr(self, "rescaler_" + c)(feats) for c in self.class_types])
+                else:
+                    logits = self.rescaler(feats)
+                return (logits,)
+
+            # --------------
+            # Task model
+            # --------------
+
+            if inform_slot_id is not None:
+                inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
+            if diag_state is not None:
+                diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
+
+            dropout_rounds = 1
+            if not suppress_dropout_passes:
+                dropout_rounds += self.dropout_rounds
+
+            dropout_total_losses = []
+            dropout_per_slot_per_example_losses = []
+            dropout_per_slot_class_losses = []
+            dropout_per_slot_class_logits = []
+            dropout_per_slot_token_losses = []
+            dropout_per_slot_start_logits = []
+            dropout_per_slot_end_logits = []
+            dropout_per_slot_refer_losses = []
+            dropout_per_slot_refer_logits = []
+            for i in range(dropout_rounds):
+                if i > 0:
+                    torch.set_grad_enabled(False)
+
+                outputs = getattr(self, self.model_type)(
+                    input_ids,
+                    attention_mask=input_mask,
+                    token_type_ids=segment_ids,
+                    position_ids=position_ids,
+                    head_mask=head_mask
+                )
+
+                sequence_output = outputs[0]
+                if self.model_type == "electra":
+                    pooled_output = self.pooler(sequence_output)
+                else:
+                    pooled_output = outputs[1]
+
+                sequence_output = self.dropout(sequence_output)
+                pooled_output = self.dropout(pooled_output)
+
+                total_loss = 0
+                per_slot_per_example_loss = {}
+                per_slot_class_loss = {}
+                per_slot_class_logits = {}
+                per_slot_token_loss = {}
+                per_slot_start_logits = {}
+                per_slot_end_logits = {}
+                per_slot_refer_loss = {}
+                per_slot_refer_logits = {}
+                for slot in self.slot_list:
+                    if self.class_aux_feats_inform and self.class_aux_feats_ds:
+                        pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
+                    elif self.class_aux_feats_inform:
+                        pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
+                    elif self.class_aux_feats_ds:
+                        pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
+                    else:
+                        pooled_output_aux = pooled_output
+                    class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
+
+                    token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
+                    start_logits, end_logits = token_logits.split(1, dim=-1)
+                    start_logits = start_logits.squeeze(-1)
+                    end_logits = end_logits.squeeze(-1)
+
+                    refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
+
+                    per_slot_class_logits[slot] = class_logits
+                    per_slot_start_logits[slot] = start_logits
+                    per_slot_end_logits[slot] = end_logits
+                    per_slot_refer_logits[slot] = refer_logits
+
+                    # If there are no labels, don't compute loss
+                    if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
+                        # If we are on multi-GPU, split add a dimension
+                        if len(start_pos[slot].size()) > 1:
+                            start_pos[slot] = start_pos[slot].squeeze(-1)
+                        if len(end_pos[slot].size()) > 1:
+                            end_pos[slot] = end_pos[slot].squeeze(-1)
+                        # sometimes the start/end positions are outside our model inputs, we ignore these terms
+                        ignored_index = start_logits.size(1) # This is a single index
+                        start_pos[slot] = start_pos[slot].clamp(0, ignored_index)
+                        end_pos[slot] = end_pos[slot].clamp(0, ignored_index)
+
+                        class_loss_fct = CrossEntropyLoss(reduction='none')
+                        token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
+                        refer_loss_fct = CrossEntropyLoss(reduction='none')
+
+                        start_loss = token_loss_fct(start_logits, start_pos[slot])
+                        end_loss = token_loss_fct(end_logits, end_pos[slot])
+                        token_loss = (start_loss + end_loss) / 2.0
+
+                        token_is_pointable = (start_pos[slot] > 0).float()
+                        if not self.token_loss_for_nonpointable:
+                            token_loss *= token_is_pointable
+
+                        refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
+                        token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
+                        if not self.refer_loss_for_nonpointable:
+                            refer_loss *= token_is_referrable
+
+                        class_loss = class_loss_fct(class_logits, class_label_id[slot])
+
+                        if self.refer_index > -1:
+                            per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
+                        else:
+                            per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
+
+                        total_loss += per_example_loss.sum()
+                        per_slot_per_example_loss[slot] = per_example_loss
+                        per_slot_class_loss[slot] = class_loss
+                        per_slot_token_loss[slot] = token_loss
+                        per_slot_refer_loss[slot] = refer_loss
+
+                dropout_total_losses.append(total_loss)
+                dropout_per_slot_per_example_losses.append(per_slot_per_example_loss)
+                dropout_per_slot_class_losses.append(per_slot_class_loss)
+                dropout_per_slot_class_logits.append(per_slot_class_logits)
+                dropout_per_slot_token_losses.append(per_slot_token_loss)
+                dropout_per_slot_start_logits.append(per_slot_start_logits)
+                dropout_per_slot_end_logits.append(per_slot_end_logits)
+                dropout_per_slot_refer_losses.append(per_slot_refer_loss)
+                dropout_per_slot_refer_logits.append(per_slot_refer_logits)
+
+                torch.set_grad_enabled(True)
+
+            outputs = (dropout_total_losses,
+                       dropout_per_slot_per_example_losses,
+                       dropout_per_slot_class_losses,
+                       dropout_per_slot_class_logits,
+                       dropout_per_slot_token_losses,
+                       dropout_per_slot_start_logits,
+                       dropout_per_slot_end_logits,
+                       dropout_per_slot_refer_losses,
+                       dropout_per_slot_refer_logits,) + outputs[2:]
+
+            return outputs
+
+    return TransformerForDST
diff --git a/dst/utils_storm_dst.py b/dst/utils_storm_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..c16bdf59d21c5f87df3a872a39065f50ac1e005f
--- /dev/null
+++ b/dst/utils_storm_dst.py
@@ -0,0 +1,237 @@
+# coding=utf-8
+#
+# Copyright 2024 Heinrich Heine University Duesseldorf
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import torch
+
+import numpy as np
+
+from statistics import (NormalDist)
+
+from utils_storm import (Results, Filter,
+                         gaussian_KL)
+
+logger = logging.getLogger(__name__)
+
+
+def batch_to_dict(batch):
+    assert len(batch) == 10
+    return  {'input_ids':       batch[0],
+             'input_mask':      batch[1], 
+             'segment_ids':     batch[2],
+             'start_pos':       batch[3],
+             'end_pos':         batch[4],
+             'inform_slot_id':  batch[5],
+             'refer_id':        batch[6],
+             'diag_state':      batch[7],
+             'class_label_id':  batch[8],
+             'ids':             batch[9]}
+
+
+class Results(Results):
+    def _restructure(self, data):
+        restructured_data = {slot: [] for slot in data[0]}
+        for d in data:
+            for s in d:
+                restructured_data[s].append(d[s])
+        for s in restructured_data:
+            restructured_data[s] = torch.stack(restructured_data[s])
+        return restructured_data
+
+
+    def update(self, losses, logits, labels):
+        self.examples['losses'] = self._restructure(losses)
+        self.examples['logits'] = self._restructure(logits)
+        self.examples['labels'] = labels
+        self.examples['probs'] = {}
+        self.examples['preds'] = {}
+
+        for slot in self.examples['logits']:
+            self.examples['probs'][slot] = torch.softmax(self.examples['logits'][slot].float(), dim=2)
+            self.examples['preds'][slot] = torch.argmax(self.examples['logits'][slot], dim=2)
+
+        self._update_agreement(labels)
+        self._update_means()
+
+
+    def _update_agreement(self, labels):
+        self.examples['agreement'] = {}
+        for slot in self.examples['preds']:
+            agreement = self.examples['preds'][slot][0] == labels[slot] # [0] is prediction for which we backprop
+            dropout_agreement_cnt = (self.examples['preds'][slot] == labels[slot]).sum(0)
+            dropout_agreement = dropout_agreement_cnt >= self.examples['preds'][slot].size(0) / 2
+            tie = dropout_agreement_cnt == self.examples['preds'][slot].size(0) / 2
+            tie_idx = tie.nonzero(as_tuple=True)[0]
+            dropout_agreement[tie_idx] = agreement[tie_idx]
+            self.examples['agreement'][slot] = dropout_agreement
+
+
+    def _update_means(self):
+        self.examples['losses_means'] = {}
+        self.examples['losses_stds'] = {}
+        for slot in self.examples['losses']:
+            self.examples['losses_means'][slot] = self.examples['losses'][slot].mean(0)
+            self.examples['losses_stds'][slot] = self.examples['losses'][slot].std(0).nan_to_num()
+
+        self.examples['probs_means'] = {}
+        self.examples['probs_stds'] = {}
+        for slot in self.examples['probs']:
+            self.examples['probs_means'][slot] = self.examples['probs'][slot].max(2)[0].mean(0)
+            self.examples['probs_stds'][slot] = self.examples['probs'][slot].max(2)[0].std(0).nan_to_num()
+
+
+class Filter(Filter):
+    def get_stats(self, name, slot):
+        return self.stats[name][slot]
+
+
+    def get_stats_tensor(self, name, slot):
+        return torch.tensor(list(self.stats[name][slot].values()))
+
+
+    def _append_new_stats(self, stats):
+        for e in stats.examples:
+            for slot in stats[e]:
+                if stats[e][slot].dim() == 1:
+                    if e not in self.stats:
+                        self.stats[e] = {}
+                    if slot not in self.stats[e]:
+                        self.stats[e][slot] = {}
+                    for l in range(self.num_labels):
+                        if l not in self.stats[e][slot] or self.window_size == 0:
+                            self.stats[e][slot][l] = torch.tensor([], dtype=stats[e][slot].dtype)
+                        if self.num_labels > 1:
+                            self.stats[e][slot][l] = torch.cat((self.stats[e][slot][l], stats[e][slot][stats['labels'][slot] == l]))
+                        else:
+                            self.stats[e][slot][l] = torch.cat((self.stats[e][slot][l], stats[e][slot]))
+                        if self.window_size > 0:
+                            self.stats[e][slot][l] = self.stats[e][slot][l][-1 * self.window_size * self.batch_size:] # sliding window
+
+
+    def update_batch_stats(self, stats):
+        self._append_new_stats(stats)
+
+        self.stats['loss_means_mean'] = {}
+        self.stats['loss_means_std'] = {}
+        self.stats['loss_stds_mean'] = {}
+        self.stats['loss_stds_std'] = {}
+        self.stats['prob_means_mean'] = {}
+        self.stats['prob_means_std'] = {}
+        self.stats['prob_stds_mean'] = {}
+        self.stats['prob_stds_std'] = {}
+        self.stats['kl'] = {}
+        self.stats['kl_mean'] = {}
+        self.stats['kl_std'] = {}
+        self.stats['ovl'] = {}
+        self.stats['ovl_mean'] = {}
+        self.stats['ovl_std'] = {}
+        if 'grad_scores' in stats.examples:
+            self.stats['grad_scores_mean'] = {}
+            self.stats['grad_scores_std'] = {}
+        for slot in self.stats['agreement']:
+            self.stats['loss_means_mean'][slot] = {}
+            self.stats['loss_means_std'][slot] = {}
+            self.stats['loss_stds_mean'][slot] = {}
+            self.stats['loss_stds_std'][slot] = {}
+            self.stats['prob_means_mean'][slot] = {}
+            self.stats['prob_means_std'][slot] = {}
+            self.stats['prob_stds_mean'][slot] = {}
+            self.stats['prob_stds_std'][slot] = {}
+            self.stats['kl'][slot] = {}
+            self.stats['kl_mean'][slot] = {}
+            self.stats['kl_std'][slot] = {}
+            self.stats['ovl'][slot] = {}
+            self.stats['ovl_mean'][slot] = {}
+            self.stats['ovl_std'][slot] = {}
+            if 'grad_scores' in stats.examples:
+                self.stats['grad_scores_mean'][slot] = {}
+                self.stats['grad_scores_std'][slot] = {}
+            for l in range(self.num_labels):
+                # Get batch loss statistics, separate by sample, then accumulated
+                (self.stats['loss_means_mean'][slot][l],
+                 self.stats['loss_means_std'][slot][l],
+                 self.stats['loss_stds_mean'][slot][l],
+                 self.stats['loss_stds_std'][slot][l]) = self._update_batch_stats(self.stats['losses_means'][slot][l],
+                                                                                  self.stats['losses_stds'][slot][l],
+                                                                                  self.stats['agreement'][slot][l])
+                
+                # Get batch probability statistics, separate by sample, then accumulated
+                (self.stats['prob_means_mean'][slot][l],
+                 self.stats['prob_means_std'][slot][l],
+                 self.stats['prob_stds_mean'][slot][l],
+                 self.stats['prob_stds_std'][slot][l]) = self._update_batch_stats(self.stats['probs_means'][slot][l],
+                                                                                  self.stats['probs_stds'][slot][l],
+                                                                                  self.stats['agreement'][slot][l])
+
+                # Get batch loss distribution KL divergence statistics
+                (self.stats['kl'][slot][l],
+                 self.stats['kl_mean'][slot][l],
+                 self.stats['kl_std'][slot][l]) = self._update_batch_kl(self.stats['losses_means'][slot][l],
+                                                                        self.stats['losses_stds'][slot][l],
+                                                                        self.stats['loss_means_mean'][slot][l],
+                                                                        self.stats['loss_stds_mean'][slot][l],
+                                                                        self.stats['agreement'][slot][l],
+                                                                        mode="kl")
+
+                # Get batch loss distribution overlap statistics
+                (self.stats['ovl'][slot][l],
+                 self.stats['ovl_mean'][slot][l],
+                 self.stats['ovl_std'][slot][l]) = self._update_batch_kl(self.stats['losses_means'][slot][l],
+                                                                         self.stats['losses_stds'][slot][l],
+                                                                         self.stats['loss_means_mean'][slot][l],
+                                                                         self.stats['loss_stds_mean'][slot][l],
+                                                                         self.stats['agreement'][slot][l],
+                                                                         mode="ovl")
+
+                # Get grad statistics
+                if 'grad_scores' in stats.examples:
+                    (self.stats['grad_scores_mean'][slot][l],
+                     self.stats['grad_scores_std'][slot][l]) = self._update_global_batch_stats(self.stats['grad_scores'][slot][l],
+                                                                                      self.stats['agreement'][slot][l])
+
+
+    def get_sample_stats(self, batch, stats, no_cat=False):
+        batch_size = batch['input_ids'].size(0)
+        kls = {slot: [] for slot in stats['agreement']}
+        ovls = {slot: [] for slot in stats['agreement']}
+        for l_itr in range(batch_size):
+            for slot in stats['agreement']:
+                d_cat = int(not no_cat) * (2 - int(stats['agreement'][slot][l_itr])) # 0, 1 or 2 (agree or disagree)
+
+                if self.num_labels > 1:
+                    rlbl = batch['class_label_id'][slot][l_itr].item()
+                else:
+                    rlbl = 0
+
+                d_losses_mean = stats['losses_means'][slot][l_itr].item()
+                d_losses_std = stats['losses_stds'][slot][l_itr].item()
+                d_probs_mean = stats['probs_means'][slot][l_itr].item()
+                d_probs_std = stats['probs_stds'][slot][l_itr].item()
+                if 'grad_scores' in stats.examples:
+                    d_grad_scores = stats['grad_scores'][slot][l_itr].item()
+
+                g_kl_div = gaussian_KL(d_losses_mean,
+                                       d_losses_std,
+                                       self.stats['loss_means_mean'][slot][rlbl][d_cat],
+                                       self.stats['loss_stds_mean'][slot][rlbl][d_cat])
+                ovl = NormalDist(mu=d_losses_mean,
+                                 sigma=max(d_losses_std, self.eps)).overlap(
+                                     NormalDist(mu=self.stats['loss_means_mean'][slot][rlbl][d_cat],
+                                                sigma=max(self.stats['loss_stds_mean'][slot][rlbl][d_cat], self.eps)))
+
+                kls[slot].append(g_kl_div)
+                ovls[slot].append(ovl)
+        return kls, ovls
diff --git a/modeling.py b/modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6493b57dd24e56d59e17592059d52ab4611172c
--- /dev/null
+++ b/modeling.py
@@ -0,0 +1,254 @@
+# coding=utf-8
+#
+# Copyright 2024 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of TripPy
+# (arXiv:2005.02877)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from torch import nn
+from torch.nn import (CrossEntropyLoss)
+from transformers import (BertModel, BertPreTrainedModel,
+                          RobertaModel, RobertaPreTrainedModel,
+                          ElectraModel, ElectraPreTrainedModel)
+
+PARENT_CLASSES = {
+    'bert': BertPreTrainedModel,
+    'roberta': RobertaPreTrainedModel,
+    'electra': ElectraPreTrainedModel
+}
+
+MODEL_CLASSES = {
+    BertPreTrainedModel: BertModel,
+    RobertaPreTrainedModel: RobertaModel,
+    ElectraPreTrainedModel: ElectraModel
+}
+
+
+class STEFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        return (input >= 0.5).float()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return torch.nn.functional.hardtanh(grad_output)
+
+
+class STEFunction01(STEFunction):
+    @staticmethod
+    def forward(ctx, input):
+        return (input >= 0.1).float()
+
+
+class STEFunction02(STEFunction):
+    @staticmethod
+    def forward(ctx, input):
+        return (input >= 0.2).float()
+
+
+class STEFunction03(STEFunction):
+    @staticmethod
+    def forward(ctx, input):
+        return (input >= 0.3).float()
+
+
+class STEFunction04(STEFunction):
+    @staticmethod
+    def forward(ctx, input):
+        return (input >= 0.4).float()
+
+
+class StraightThroughEstimator(nn.Module):
+    def __init__(self, t):
+        super(StraightThroughEstimator, self).__init__()
+        self.t = t
+
+    def forward(self, x):
+        if self.t == 0.1:
+            x = STEFunction01.apply(x)
+        elif self.t == 0.2:
+            x = STEFunction02.apply(x)
+        elif self.t == 0.3:
+            x = STEFunction03.apply(x)
+        elif self.t == 0.4:
+            x = STEFunction04.apply(x)
+        else:
+            x = STEFunction.apply(x)
+        return x
+
+
+class ElectraPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+        
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+def TransformerForGlue(parent_name):
+    if parent_name not in PARENT_CLASSES:
+        raise ValueError("Unknown model %s" % (parent_name))
+
+    class TransformerForGlue(PARENT_CLASSES[parent_name]):
+        def __init__(self, config):
+            assert config.model_type in PARENT_CLASSES
+            assert self.__class__.__bases__[0] in MODEL_CLASSES
+            super(TransformerForGlue, self).__init__(config)
+            self.model_type = config.model_type
+            self.num_labels = config.num_labels
+            self.dropout_rounds = config.dropout_rounds
+            self.use_tfidf = config.use_tfidf
+            self.tfidf_dim = config.tfidf_dim if hasattr(config, "tfidf_dim") else 0
+            self.no_class_separation = config.no_class_separation
+            self.rescaler_featnum = config.rescaler_featnum
+            self.rescaler_binary = config.rescaler_binary
+            self.rescaler_binary_threshold = config.rescaler_binary_threshold
+            self.config = config
+
+            self.add_module(self.model_type, MODEL_CLASSES[self.__class__.__bases__[0]](config))
+            if self.model_type == "electra":
+                self.pooler = ElectraPooler(config)
+
+            classifier_dropout = (
+                config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+            )
+            self.dropout = nn.Dropout(classifier_dropout)
+
+            if self.use_tfidf:
+                self.classifier = nn.Linear(self.tfidf_dim, config.num_labels)
+            else:
+                self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+            if not self.no_class_separation:
+                for c in range(self.num_labels):
+                    self.add_module("rescaler_" + str(c), nn.Sequential(
+                        nn.Linear(self.rescaler_featnum, self.rescaler_featnum),
+                        nn.BatchNorm1d(self.rescaler_featnum, affine=False),
+                        nn.ReLU(),
+                        nn.Linear(self.rescaler_featnum, 2),
+                        nn.BatchNorm1d(2, affine=False),
+                        nn.Softmax(dim=1),
+                    ))
+            else:
+                self.rescaler = nn.Sequential(
+                    nn.Linear(self.rescaler_featnum, self.rescaler_featnum),
+                    nn.BatchNorm1d(self.rescaler_featnum, affine=False),
+                    nn.ReLU(),
+                    nn.Linear(self.rescaler_featnum, 2),
+                    nn.BatchNorm1d(2, affine=False),
+                    nn.Softmax(dim=1),
+                )
+            if self.rescaler_binary:
+                if not self.no_class_separation:
+                    for c in range(self.num_labels):
+                        getattr(self, "rescaler_" + str(c)).add_module(str(len(getattr(self, "rescaler_" + str(c))) + 1),
+                                                                       StraightThroughEstimator(config.rescaler_binary_threshold))
+                else:
+                    self.rescaler.add_module(str(len(self.rescaler) + 1),
+                                             StraightThroughEstimator(config.rescaler_binary_threshold))
+
+            # Initialize weights and apply final processing
+            self.post_init()
+
+        def forward(
+            self,
+            input_ids = None,
+            attention_mask = None,
+            token_type_ids = None,
+            position_ids = None,
+            head_mask = None,
+            inputs_embeds = None,
+            labels = None,
+            ids = None,
+            feats = None,
+            output_attentions = None,
+            output_hidden_states = None,
+            no_grad = False,
+            suppress_dropout_passes = False,
+            mode = "default"):
+
+            # --------------
+            # Rescaler model
+            # --------------
+
+            if mode == "rescaler":
+                if not self.no_class_separation:
+                    logits = tuple([getattr(self, "rescaler_" + str(c))(feats) for c in range(self.num_labels)])
+                else:
+                    logits = self.rescaler(feats)
+                return (logits,)
+
+            # --------------
+            # Task model
+            # --------------
+
+            dropout_rounds = 1
+            if not suppress_dropout_passes:
+                dropout_rounds += self.dropout_rounds
+
+            dropout_losses = []
+            dropout_logits = []
+            for i in range(dropout_rounds):
+                if i > 0:
+                    torch.set_grad_enabled(False)
+
+                outputs = getattr(self, self.model_type)(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    token_type_ids=token_type_ids,
+                    position_ids=position_ids,
+                    head_mask=head_mask,
+                    inputs_embeds=inputs_embeds,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states)
+
+                if self.model_type == "electra":
+                    pooled_output = self.pooler(outputs[0])
+                else:
+                    pooled_output = outputs[1]
+
+                pooled_output = self.dropout(pooled_output)
+                if self.use_tfidf:
+                    if i > 0:
+                        logits = self.classifier(self.dropout(feats))
+                    else:
+                        logits = self.classifier(feats)
+                else:
+                    logits = self.classifier(pooled_output)
+
+                if labels is not None:
+                    loss_fct = CrossEntropyLoss(reduction='none')
+                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+                dropout_losses.append(loss)
+                dropout_logits.append(logits)
+
+                torch.set_grad_enabled(True)
+
+            outputs = (torch.stack(dropout_losses),
+                       torch.stack(dropout_logits),) + outputs[2:]
+
+            return outputs
+
+    return TransformerForGlue
diff --git a/storm.py b/storm.py
new file mode 100644
index 0000000000000000000000000000000000000000..96916bdf7d8b5cd2deb4d158c546688164b344e7
--- /dev/null
+++ b/storm.py
@@ -0,0 +1,787 @@
+# coding=utf-8
+#
+# Copyright 2024 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import autograd_hacks
+import copy
+import gzip
+import higher
+import json
+import logging
+import math
+import os
+import pickle
+import random
+import re
+import torch
+import transformers
+import warnings
+
+import numpy as np
+
+from sklearn.metrics import (f1_score, matthews_corrcoef)
+from tensorboardX import (SummaryWriter)
+from torch.utils.data import (DataLoader, WeightedRandomSampler)
+from tqdm.auto import (tqdm)
+from transformers import (
+    BertConfig, BertModel, BertTokenizer,
+    RobertaConfig, RobertaModel, RobertaTokenizer,
+    ElectraConfig, ElectraModel, ElectraTokenizer)
+
+from agra import (AGRA)
+from modeling import (TransformerForGlue)
+from utils_gpu import (from_device, to_device)
+from utils_storm import (print_header, Results, Filter)
+
+warnings.filterwarnings("ignore", "Detected call of `lr_scheduler\.step\(\)` before `optimizer\.step\(\)`\.", UserWarning)
+
+MODEL_CLASSES = {
+    'bert': (BertConfig, TransformerForGlue('bert'), BertModel, BertTokenizer),
+    'roberta': (RobertaConfig, TransformerForGlue('roberta'), RobertaModel, RobertaTokenizer),
+    'electra': (ElectraConfig, TransformerForGlue('electra'), ElectraModel, ElectraTokenizer),
+}
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    def list_of_strings(arg):
+        return arg.split(',')
+    
+    parser = argparse.ArgumentParser(description="STORM for text classification tasks.")
+
+    # Required parameters
+    parser.add_argument("--task_name", type=str, default=None, required=True,
+                        help="The name of the glue task to train on.")
+    parser.add_argument("--model_type", type=str, default=None, required=True,
+                        help="Model type.",
+                        choices=list(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", type=str, default=None, required=True,
+                        help="Path to pretrained model or model identifier from huggingface.co/models.")
+    parser.add_argument("--train_file", type=str, default=None, required=True,
+                        help="A tsv file containing the training data.")
+    parser.add_argument("--validation_file", type=str, default=None, required=True,
+                        help="A tsv file containing the validation data.")
+
+    # Other parameters
+    parser.add_argument("--test_file", type=str, default=None,
+                        help="A tsv file containing the test data. Loading a test file overwrites --crossvalid_fold.")
+    parser.add_argument("--max_length", type=int, default=128,
+                        help="The maximum total input sequence length after tokenization. "
+                             "Sequences longer than this will be truncated, sequences shorter "
+                             "will be padded if `--pad_to_max_lengh` is passed.")
+    parser.add_argument("--pad_to_max_length", action="store_true",
+                        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.")
+    parser.add_argument("--train_batch_size", type=int, default=8,
+                        help="Batch size for the training dataloader.")
+    parser.add_argument("--eval_batch_size", type=int, default=8,
+                        help="Batch size for the evaluation dataloader.")
+    parser.add_argument("--learning_rate", type=float, default=5e-5,
+                        help="Initial learning rate (after the potential warmup period) to use.")
+    parser.add_argument("--rescaler_learning_rate", type=float, default=1e-2,
+                        help="Initial learning rate (after the potential warmup period) to use.")
+    parser.add_argument("--weight_decay", type=float, default=0.0,
+                        help="Weight decay to use.")
+    parser.add_argument("--num_train_epochs", type=int, default=3,
+                        help="Total number of training epochs to perform.")
+    parser.add_argument("--max_train_steps", type=int, default=None,
+                        help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
+    parser.add_argument("--lr_scheduler_type", type=str, default="constant",
+                        help="The scheduler type to use.",
+                        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"])
+    parser.add_argument("--warmup_proportion", type=float, default=0.0,
+                        help="Proportion of steps for the warmup in the lr scheduler.")
+    parser.add_argument("--output_dir", type=str, default=None,
+                        help="Where to store the final model.")
+    parser.add_argument("--seed", type=int, default=42,
+                        help="A seed for reproducible training.")
+    parser.add_argument('--local_files_only', action='store_true',
+                        help="Whether to only load local model files (useful when working offline).")
+    parser.add_argument('--logging_steps', type=int, default=100,
+                        help="Log every X updates steps.")
+    parser.add_argument("--save_checkpoints", action='store_true',
+                        help="When set, saves model checkpoint after every epoch.")
+    parser.add_argument("--save_stats", action='store_true',
+                        help="When set, saves detailed training statistics as gzipped json.")
+    parser.add_argument('--stats_window', type=int, default=1,
+                        help="Window size for sample statistics memory (in number of batches).")
+    parser.add_argument("--simulate_only", action='store_true',
+                        help="When set, rescaling is not actually applied (but still learned).")
+    parser.add_argument('--dropout_rounds', type=int, default=2,
+                        help="Number of additional forward passes to compute sample statistics.")
+    parser.add_argument("--no_cat", action='store_true',
+                        help="When set, sample statistics are computed by label-prediction agreement.")
+    parser.add_argument("--agra", action='store_true',
+                        help="Use original AGRA method instead of rescaler. "
+                             "AGRA does not use meta learning.")
+    parser.add_argument("--agra_weighted", action='store_true',
+                        help="When set, uses class weighting to sample comparison batches.")
+    parser.add_argument("--agra_loss", type=str, default="CE",
+                        help="Comparison loss for AGRA.",
+                        choices=["CE", "F1"])
+    parser.add_argument("--agra_layer_groups", type=list_of_strings, default='classifier',
+                        help="Layer groups to consider for AGRA.")
+    parser.add_argument("--use_tfidf", action='store_true',
+                        help="When set, loads TF-IDF features instead of using a transformer encoder. Requires --tfidf_path.")
+    parser.add_argument("--tfidf_path", type=str, default=None,
+                        help="Path to TF-IDF features to be loaded when using --use_tdidf.")
+    parser.add_argument("--evaluate_rescaling", action='store_true',
+                        help="When set, evaluate rescaling performance given a dataset with corruption labels "
+                              "that indicate whether the original label is noisy. "
+                              "Note: You need to ensure that such labels exist in the data, "
+                              "as the code does not assert this automatically.")
+    parser.add_argument('--crossvalid_fold', type=int, default=-1,
+                        help="When set to 0 or 1, uses 2-fold cross-validation given the --validation_file. "
+                             "Usage: Run this script separately once for fold 0 and 1 "
+                             "and then manually pool your results accordingly.",
+                        choices=[-1, 0, 1])
+    parser.add_argument("--meta_dataset", type=str, default="train",
+                        help="Dataset used for the meta update step (default for STORM: 'train').",
+                        choices=["train", "eval"])
+    parser.add_argument('--meta_innerloop_rounds', type=int, default=1,
+                        help="Number of inner loop traversals for meta learning.")
+    parser.add_argument("--rescaler_binary", action='store_true',
+                        help="When set, rescaler will produce binary loss weights, i.e., 0 or 1.")
+    parser.add_argument("--rescaler_binary_threshold", type=float, default=0.5,
+                        help="Confidence threshold for binary rescaler.",
+                        choices=[0.1, 0.2, 0.3, 0.4, 0.5])
+    parser.add_argument("--rescaler_feats", type=str, default="default",
+                        help="Feature set for rescaler.",
+                        choices=["default", "loss"])
+    parser.add_argument("--rescaler_feats_no_cat", action='store_true',
+                        help="When set, cat is not added to the feature set of the rescaler.")
+    parser.add_argument("--no_class_separation", action='store_true',
+                        help="When set, sample statistics are not separated by target classes.")
+    parser.add_argument("--no_meta_loss_rescaling", action='store_true',
+                        help="When set, no meta loss rescaling is used.")
+
+    args = parser.parse_args()
+
+    assert not args.use_tfidf or args.tfidf_path is not None
+
+    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    return args
+
+
+def evaluate_task(pred, ref, metrics):
+    results = {}
+    if "accuracy" in metrics:
+        results['accuracy'] = (np.array(pred) == np.array(ref)).mean()
+    if "f1" in metrics:
+        results['f1'] = f1_score(y_true=ref, y_pred=pred)
+    if "matthews_correlation" in metrics:
+        results['matthews_correlation'] = matthews_corrcoef(y_true=ref, y_pred=pred)
+    return results
+
+
+def load_raw_dataset(input_file, data_specs, has_corruption_labels=False):
+    def label_map(label, lbl_map):
+        if lbl_map is not None:
+            return lbl_map[label]
+        else:
+            return int(label)
+
+    raw_data = []
+    id_to_pos = {}
+    unique_labels = set()
+    corrupted_ids = []
+    with open(input_file, "r", encoding='utf-8-sig') as f:
+        for l_itr, line in enumerate(f):
+            data_point = {'labels': None, 'sentence1': None, 'sentence2': None, 'ids': None}
+            if data_specs['has_header'] and l_itr == 0:
+                continue
+            raw_data_point = line.strip().split('\t')
+            data_point['labels'] = label_map(raw_data_point[data_specs['label']], data_specs['label_map'])
+            data_point['sentence1'] = raw_data_point[data_specs['sentence1']]
+            data_point['sentence2'] = raw_data_point[data_specs['sentence2']] if data_specs['sentence2'] is not None else None
+            data_point['ids'] = l_itr - 1 if data_specs['has_header'] else l_itr
+            id_to_pos[data_point['ids']] = len(raw_data)
+            raw_data.append(data_point)
+            if has_corruption_labels and raw_data_point[-1] == 'True':
+                corrupted_ids.append(data_point['ids'])
+            unique_labels.add(data_point['labels'])
+    return raw_data, list(unique_labels), id_to_pos, corrupted_ids
+
+
+def process_raw_dataset(dataset, tokenizer, padding, max_length):
+    processed_dataset = []
+    for i in tqdm(dataset, desc="Running tokenizer on dataset"):
+        # Tokenize the texts
+        text = ((i['sentence1'],) if i['sentence2'] is None else (i['sentence1'], i['sentence2']))
+        result = tokenizer(*text, padding=padding, max_length=max_length, truncation=True)
+        if 'ids' in i:
+            result['ids'] = i['ids']
+        if 'labels' in i:
+            result['labels'] = i['labels']
+        processed_dataset.append(result)
+    return processed_dataset
+
+
+# For TF-IDF features. May be removed later.
+def add_feats_to_raw_dataset(dataset, feats_path):
+    feats_dict = pickle.load(open(feats_path, "rb"))
+    for i in dataset:
+        ids = str(i['ids'])
+        if ids in feats_dict:
+            i['feats'] = feats_dict[ids].tolist()
+        else:
+            raise Exception("Feats not found")
+
+
+def evaluate_filtering(corrupted_ids, filter_list_by_epoch, epoch, data_size):
+    prediction_size = len(filter_list_by_epoch[epoch])
+    c_tp = 0
+    c_fn = 0
+    for i in corrupted_ids:
+        if i in filter_list_by_epoch[epoch]:
+            c_tp += 1
+        else:
+            c_fn += 1
+    c_fp = prediction_size - c_tp
+    c_tn = data_size - c_tp - c_fp - c_fn
+    precision = c_tp / (c_tp + c_fp + 1e-8)
+    recall = c_tp / (c_tp + c_fn + 1e-8)
+    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
+    acc = (c_tp + c_tn) / (c_tp + c_fp + c_tn + c_fn + 1e-8)
+    specificity = c_tn / (c_tn + c_fp + 1e-8)
+    return precision, recall, f1, specificity, acc, c_tp, c_tn, c_fp, c_fn
+
+
+def call_rescaler(args, model, epoch_filter, step_results, kls, ovls, per_example_loss, batch, num_labels, corrupted_ids):
+    d_cat = int(not args.no_cat) * (2 - step_results['agreement'].type(torch.int)) # 0, 1 or 2 (agree or disagree)
+    rlbls = torch.full(batch['labels'].size(), 0) if args.no_class_separation else from_device(batch['labels'])
+    if args.rescaler_feats == "default":
+        feats = torch.cat((step_results['losses_means'].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('loss_means_mean')[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('loss_means_std')[rlbls, d_cat].reshape(-1,1),
+                           step_results['losses_stds'].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('loss_stds_mean')[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('loss_stds_std')[rlbls, d_cat].reshape(-1,1),
+                           step_results['probs_means'].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('prob_means_mean')[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('prob_means_std')[rlbls, d_cat].reshape(-1,1),
+                           step_results['probs_stds'].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('prob_stds_mean')[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('prob_stds_std')[rlbls, d_cat].reshape(-1,1),
+                           torch.tensor(kls).reshape(-1,1),
+                           epoch_filter.get_stats_tensor('kl_mean')[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('kl_std')[rlbls, d_cat].reshape(-1,1),
+                           torch.tensor(ovls).reshape(-1,1),
+                           epoch_filter.get_stats_tensor('ovl_mean')[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('ovl_std')[rlbls, d_cat].reshape(-1,1)),
+                          dim=1).type(torch.float).to(args.device)
+    elif args.rescaler_feats == "loss":
+        feats = per_example_loss.reshape(-1,1).detach().clone() # Loss w/o grad
+    else:
+        raise Exception("Unknown rescaler features.")
+    if not args.rescaler_feats_no_cat:
+        feats = torch.cat((feats, step_results['agreement'].unsqueeze(1).type(torch.float).to(args.device)), dim=1)
+    rescaler_outputs = model(feats=feats, mode="rescaler")
+    if not args.no_class_separation:
+        rescaler_outputs = rescaler_outputs[0]
+        new_rescaler_outputs = rescaler_outputs[0] * (step_results['labels'] == 0).unsqueeze(1).to(args.device)
+        for c in range(1, num_labels):
+            new_rescaler_outputs += rescaler_outputs[c] * (step_results['labels'] == c).unsqueeze(1).to(args.device)
+        rescaler_outputs = (new_rescaler_outputs,)
+    rescaler_outputs = rescaler_outputs[0][:,1]
+    return rescaler_outputs
+
+
+def main():
+    args = parse_args()
+
+    # Set up logging, print header
+    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+                        datefmt="%m/%d/%Y %H:%M:%S",
+                        level=logging.INFO)
+    logger.setLevel(logging.INFO)
+    transformers.utils.logging.set_verbosity_info()
+    tb_writer = SummaryWriter()
+
+    print_header()
+
+    for a in vars(args):
+        logger.info("{:40s} {:s}".format(a, str(getattr(args, a))))
+    logger.info("")
+        
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        transformers.set_seed(args.seed)
+
+    # Handle the repository creation
+    if args.output_dir is not None:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    # Loading the datasets.
+    data_specs = {'rte': {'has_header': True, 'label': 3, 'sentence1': 1, 'sentence2': 2, 'label_map': {'not_entailment': 0, 'entailment': 1}},
+                  'mrpc': {'has_header': True, 'label': 0, 'sentence1': 3, 'sentence2': 4, 'label_map': None},
+                  'cola': {'has_header': False, 'label': 1, 'sentence1': 3, 'sentence2': None, 'label_map': None},
+                  'youtube': {'has_header': True, 'label': 2, 'sentence1': 1, 'sentence2': None, 'label_map': None},
+                  'sms': {'has_header': True, 'label': 2, 'sentence1': 1, 'sentence2': None, 'label_map': None}}
+    raw_datasets = {}
+    raw_datasets["train"], label_list, train_id_to_pos, corrupted_ids = load_raw_dataset(args.train_file,
+                                                                                         data_specs[args.task_name],
+                                                                                         has_corruption_labels=args.evaluate_rescaling)
+    raw_datasets["validation"], _, _, _ = load_raw_dataset(args.validation_file, data_specs[args.task_name])
+    if args.test_file is not None:
+        raw_datasets["test"], _, _, _ = load_raw_dataset(args.test_file, data_specs[args.task_name])
+    elif args.crossvalid_fold == 0:
+        raw_datasets["test"] = raw_datasets["validation"][int(len(raw_datasets["validation"]) / 2):]
+        raw_datasets["validation"] = raw_datasets["validation"][:int(len(raw_datasets["validation"]) / 2)]
+    elif args.crossvalid_fold == 1:
+        raw_datasets["test"] = raw_datasets["validation"][:int(len(raw_datasets["validation"]) / 2)]
+        raw_datasets["validation"] = raw_datasets["validation"][int(len(raw_datasets["validation"]) / 2):]
+
+
+    # Labels
+    label_list.sort()  # Let's sort it for determinism
+    num_labels = len(label_list)
+
+    config_class, model_class, default_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+    # Load pretrained model and tokenizer
+    config = config_class.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name, local_files_only=args.local_files_only)
+    config.dropout_rounds = args.dropout_rounds
+    config.use_tfidf = args.use_tfidf
+    config.no_class_separation = args.no_class_separation
+    config.rescaler_binary = args.rescaler_binary
+    config.rescaler_binary_threshold = args.rescaler_binary_threshold
+    if args.rescaler_feats == "default":
+        config.rescaler_featnum = 18
+    elif args.rescaler_feats == "loss":
+        config.rescaler_featnum = 1
+    else:
+        raise Exception("Unknown rescaler features.")
+    if not args.rescaler_feats_no_cat:
+        config.rescaler_featnum += 1
+
+    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, local_files_only=args.local_files_only)
+    padding = "max_length" if args.pad_to_max_length else False
+    train_dataset = process_raw_dataset(raw_datasets["train"], tokenizer, padding, args.max_length)
+    eval_dataset = process_raw_dataset(raw_datasets["validation"], tokenizer, padding, args.max_length)
+    if args.test_file is not None or args.crossvalid_fold > -1:
+        test_dataset = process_raw_dataset(raw_datasets["test"], tokenizer, padding, args.max_length)
+    if args.use_tfidf:
+        add_feats_to_raw_dataset(train_dataset, os.path.join(args.tfidf_path, "train_feats.pickle"))
+        add_feats_to_raw_dataset(eval_dataset, os.path.join(args.tfidf_path, "valid_feats.pickle"))
+        if args.test_file is not None:
+            add_feats_to_raw_dataset(test_dataset, os.path.join(args.tfidf_path, "test_feats.pickle"))
+        elif args.crossvalid_fold > -1:
+            add_feats_to_raw_dataset(test_dataset, os.path.join(args.tfidf_path, "valid_feats.pickle"))
+        config.tfidf_dim = len(train_dataset[0]['feats'])
+    model = model_class.from_pretrained(
+        args.model_name_or_path,
+        from_tf=bool(".ckpt" in args.model_name_or_path),
+        config=config,
+        local_files_only=args.local_files_only)
+
+    # DataLoaders creation:
+    if args.pad_to_max_length:
+        # If padding was already done ot max length, we use the default data collator that will just convert everything
+        # to tensors.
+        data_collator = transformers.default_data_collator
+    else:
+        # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
+        # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
+        # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
+        data_collator = transformers.DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(None))
+
+    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.eval_batch_size)
+    if args.test_file is not None or args.crossvalid_fold > -1:
+        test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=args.eval_batch_size)
+
+    # Optimizer
+    # Split weights in two groups, one with weight decay and the other not.
+    no_decay = ["bias", "LayerNorm.weight"]
+    inner_optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and not "rescaler" in n],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and not "rescaler" in n],
+            "weight_decay": 0.0,
+        },
+    ]
+    outer_optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and "rescaler" in n],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and "rescaler" in n],
+            "weight_decay": 0.0,
+        },
+    ]
+    inner_optimizer = torch.optim.Adam(inner_optimizer_grouped_parameters, lr=args.learning_rate)
+    meta_optimizer = torch.optim.Adam(outer_optimizer_grouped_parameters, lr=args.rescaler_learning_rate)
+
+    model.to(args.device)
+
+    # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
+    # shorter in multiprocess)
+
+    # Scheduler and math around the number of training steps.
+    num_update_steps_per_epoch = math.ceil(len(train_dataset) / args.train_batch_size)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # t_total
+    else:
+        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # This would jointly regulate the lr for inner_optimizer and meta_optimizer
+    # if both optimizers would use the same grouped_parameters, i.e.,
+    # the lr is tied to the parameters, not to the optimizer.
+    lr_scheduler = transformers.get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=inner_optimizer,
+        num_warmup_steps=int(args.max_train_steps * args.warmup_proportion),
+        num_training_steps=args.max_train_steps,
+    )
+
+    if args.agra:
+        agra = AGRA(args.agra_loss,
+                    num_labels,
+                    args.agra_weighted,
+                    train_dataset,
+                    model=model,
+                    device=args.device,
+                    window_size=args.stats_window)
+
+    # Train!
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Total train batch size = {args.train_batch_size}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+
+    progress_bar = tqdm(range(args.max_train_steps))
+    completed_steps = 0
+    tr_loss, logging_loss = 0.0, 0.0
+
+    stat_list = {}
+    filter_list_by_epoch = {-1: []}
+    for epoch in range(args.num_train_epochs):
+        train_dataloader = DataLoader(
+            train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.train_batch_size
+        )
+
+        train_inner_dataloader = DataLoader(
+            train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.train_batch_size
+        )
+
+        # Meta dataloader
+        if args.meta_dataset == "eval":
+            meta_dataset = eval_dataset
+        else:
+            meta_dataset = train_dataset
+        meta_sampler = WeightedRandomSampler(torch.ones(len(meta_dataset)), len(meta_dataset))
+        meta_dataloader = DataLoader(
+            meta_dataset, collate_fn=data_collator, batch_size=args.train_batch_size, sampler=meta_sampler)
+
+        if args.agra:
+            agra.build_dataloader(data_collator, args.train_batch_size)
+
+        model.train()
+
+        filter_list_by_epoch[epoch] = []
+        # Setting this to 1 will cause Filter to not separate stats by target class
+        num_filter_labels = 1 if args.no_class_separation else num_labels 
+        epoch_filter = Filter(num_filter_labels,
+                              args.stats_window,
+                              args.train_batch_size)
+        if not args.no_meta_loss_rescaling:
+            meta_epoch_filter = Filter(num_filter_labels,
+                                       args.stats_window,
+                                       args.train_batch_size)
+        for step, orig_batch in enumerate(train_dataloader):
+            orig_batch = to_device(orig_batch, args.device)
+            batch_size = orig_batch['input_ids'].size(0)
+
+            meta_batch = next(iter(meta_dataloader))
+            meta_batch = to_device(meta_batch, args.device)
+            meta_optimizer.zero_grad()
+
+            with higher.innerloop_ctx(model, inner_optimizer, copy_initial_weights=False, track_higher_grads=True) as (fmodel, diffopt):
+                for ir in range(args.meta_innerloop_rounds):
+                    if ir > 0:
+                        batch = next(iter(train_inner_dataloader))
+                        batch = to_device(batch, args.device)
+                    else:
+                        batch = orig_batch
+
+                    if args.agra:
+                        grad_scores = agra.agra_step(batch, args.agra_layer_groups)
+
+                    outputs = fmodel(**batch)
+
+                    per_example_loss = outputs[0][0] # This is the loss we use for backprop
+                    outputs = from_device(outputs)
+                    labels = from_device(batch['labels'])
+                    step_results = Results(outputs[0], outputs[1], labels)
+
+                    if args.agra:
+                        step_results['grad_scores'] = grad_scores
+
+                    orig_loss = per_example_loss.sum().item()
+
+                    # Get batch loss/grad statistics, accumulated across samples
+                    epoch_filter.update_batch_stats(step_results)
+
+                    # Get stats
+                    kls, ovls = epoch_filter.get_sample_stats(batch,
+                                                              step_results,
+                                                              args.no_cat)
+
+                    # ---------------------------
+                    # Rescale sample losses
+                    # ---------------------------
+
+                    if args.agra:
+                        # AGRA baseline
+                        inner_rescaler_outputs = (grad_scores >= 0).float().to(args.device)
+                    else:
+                        # Get weights for each sample
+                        inner_rescaler_outputs = call_rescaler(args, fmodel,
+                                                               epoch_filter, step_results,
+                                                               kls, ovls,
+                                                               per_example_loss, batch,
+                                                               num_labels, corrupted_ids)
+
+                    # Rescale
+                    if not args.simulate_only:
+                        per_example_loss *= inner_rescaler_outputs
+
+                    # Get stats
+                    if ir == 0:
+                        filter_ids = batch['ids'][(inner_rescaler_outputs < 0.5).cpu()].tolist()
+                        filter_list_by_epoch[epoch] += filter_ids
+
+                        for l_itr in range(batch_size):
+                            batch_id = batch['ids'][l_itr].item()
+                            label = batch['labels'][l_itr].item()
+
+                            d_cat = int(not args.no_cat) * (2 - int(step_results['agreement'][l_itr])) # 0, 1 or 2 (agree or disagree)
+                            rlbl = 0 if args.no_class_separation else label
+
+                            # Collect statistics
+                            if batch_id not in stat_list:
+                                stat_list[batch_id] = []
+                            stat_list[batch_id].append(
+                                {"is_filtered": batch_id in filter_ids,
+                                 "epoch": epoch,
+                                 "label": label,
+                                 "weight": inner_rescaler_outputs[l_itr].item(),
+                                 "guid": batch_id,
+                                 "d_preds": step_results['preds'][:,l_itr].tolist(),
+                                 "d_losses": step_results['losses'][:,l_itr].tolist(),
+                                 "d_losses_mean": step_results['losses_means'][l_itr].item(),
+                                 "d_losses_std": step_results['losses_stds'][l_itr].item(),
+                                 "d_losses_mean_batch": epoch_filter.get_stats('loss_means_mean')[rlbl][d_cat],
+                                 "d_losses_std_batch": epoch_filter.get_stats('loss_stds_mean')[rlbl][d_cat],
+                                 "d_probs": step_results['probs'][:,l_itr].max(1)[0].tolist(),
+                                 "d_probs_mean": step_results['probs_means'][l_itr].item(),
+                                 "d_probs_std": step_results['probs_stds'][l_itr].item(),
+                                 "grad_scores": step_results['grad_scores'][l_itr].item() if 'grad_scores' in step_results.examples else 0.0,
+                                 "d_probs_mean_batch": epoch_filter.get_stats('prob_means_mean')[rlbl][d_cat],
+                                 "d_probs_std_batch": epoch_filter.get_stats('prob_means_std')[rlbl][d_cat],
+                                 "g_kl_div": kls[l_itr],
+                                 "g_kl_div_batch": epoch_filter.get_stats('kl_mean')[rlbl][d_cat],
+                                 "ovl": ovls[l_itr],
+                                 "ovl_batch": epoch_filter.get_stats('ovl_mean')[rlbl][d_cat]})
+
+                    loss = per_example_loss.sum()
+                    tr_loss += loss.item()
+
+                    diffopt.step(loss) # Perform a training step
+
+                    if ir > 0:
+                        batch = from_device(batch)
+
+                # End of inner loop(s)
+
+                # Meta update
+                if not args.agra:
+                    meta_outputs = fmodel(**meta_batch)
+                    meta_per_example_loss = meta_outputs[0][0] # This is the loss we use for backprop
+                    meta_outputs = from_device(meta_outputs)
+
+                    if not args.no_meta_loss_rescaling:
+                        meta_step_results = Results(meta_outputs[0], meta_outputs[1], from_device(meta_batch['labels']))
+                        meta_epoch_filter.update_batch_stats(meta_step_results)
+                        meta_kls, meta_ovls = meta_epoch_filter.get_sample_stats(meta_batch,
+                                                                                 meta_step_results,
+                                                                                 args.no_cat)
+                        rescaler_outputs = call_rescaler(args, fmodel,
+                                                         meta_epoch_filter, meta_step_results,
+                                                         meta_kls, meta_ovls,
+                                                         meta_per_example_loss, meta_batch,
+                                                         num_labels, corrupted_ids)
+                        if not args.simulate_only:
+                            meta_per_example_loss *= rescaler_outputs
+
+                    meta_loss = meta_per_example_loss.sum()
+                    meta_loss.backward()
+
+            # Meta update
+            if not args.agra:
+                meta_optimizer.step()
+                meta_optimizer.zero_grad()
+
+            # Update model parameters
+            model.roberta.load_state_dict(fmodel.roberta.state_dict())
+            model.classifier.load_state_dict(fmodel.classifier.state_dict())
+
+            lr_scheduler.step()
+            progress_bar.update(1)
+            completed_steps += 1
+
+            orig_batch = from_device(orig_batch)
+
+            # Log metrics
+            if tb_writer and args.logging_steps > 0 and completed_steps % args.logging_steps == 0:
+                tb_writer.add_scalar('lr', lr_scheduler.get_last_lr()[0], completed_steps)
+                tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, completed_steps)
+                logging_loss = tr_loss
+
+            if completed_steps >= args.max_train_steps:
+                break
+
+        model.eval()
+        eval_metric = {"predictions": [], "references": []}
+        for step, batch in enumerate(eval_dataloader):
+            batch = to_device(batch, args.device)
+            outputs = model(**batch, suppress_dropout_passes=True)
+            predictions = outputs[1][0].argmax(dim=-1)
+            eval_metric["predictions"] += predictions.tolist()
+            eval_metric["references"] += batch["labels"].tolist()
+        if args.test_file is not None or args.crossvalid_fold > -1:
+            test_metric = {"predictions": [], "references": []}
+            for step, batch in enumerate(test_dataloader):
+                batch = to_device(batch, args.device)
+                outputs = model(**batch, suppress_dropout_passes=True)
+                predictions = outputs[1][0].argmax(dim=-1)
+                test_metric["predictions"] += predictions.tolist()
+                test_metric["references"] += batch["labels"].tolist()
+
+        logger.info("")
+        logger.info("Filtered samples (loss weight < 0.5) in epoch %s: %s of %s" % (epoch,
+                                                                                    len(filter_list_by_epoch[epoch]),
+                                                                                    len(train_dataset)))
+        tb_writer.add_scalars('filtered_samples', {'Filtered samples': len(filter_list_by_epoch[epoch]),
+                                                   'Total': len(train_dataset)}, epoch)
+
+        # Evaluate rescaling/filtering performance
+        if args.evaluate_rescaling:
+            (precision, recall,
+             f1, specificity, acc,
+             c_tp, c_tn, c_fp, c_fn) = evaluate_filtering(corrupted_ids, filter_list_by_epoch, epoch, len(train_dataset))
+            logger.info("Filter performance in epoch %s: " \
+                        "Precision %.4f, Recall %.4f, F1 %.4f, Accuracy %.4f, " \
+                        "Specificity: %.4f (TP: %s, TN: %s, FP: %s, FN: %s)" % (epoch, precision,
+                                                                                recall, f1,
+                                                                                acc, specificity,
+                                                                                c_tp, c_tn, c_fp, c_fn))
+            tb_writer.add_scalars('filter_performance', {'Precision': precision,
+                                                         'Recall': recall,
+                                                         'F1': f1,
+                                                         'Specificity': specificity,
+                                                         'Accuracy': acc}, epoch)
+
+        # Evaluate task performance
+        eval_result = evaluate_task(eval_metric['predictions'], eval_metric['references'], ["accuracy", "f1", "matthews_correlation"])
+        logger.info(f"eval epoch {epoch}: {eval_result}")
+        tb_writer.add_scalars('eval_metric', eval_result, epoch)
+        if args.test_file is not None or args.crossvalid_fold > -1:
+            test_result = evaluate_task(test_metric['predictions'], test_metric['references'], ["accuracy", "f1", "matthews_correlation"])
+            logger.info(f"test epoch {epoch}: {test_result}")
+            tb_writer.add_scalars('test_metric', test_result, epoch)
+
+        # Save model checkpoint
+        if args.output_dir is not None and args.save_checkpoints:
+            output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(completed_steps))
+            logger.info("Saving model checkpoint to %s", output_dir)
+            if not os.path.exists(output_dir):
+                os.makedirs(output_dir)
+            model.save_pretrained(output_dir)
+
+    # After training, print detailed statistics
+    cc_stats = {e: {'cnt': 1e-8, 'filtered': 0} for e in range(args.num_train_epochs)}
+    cw_stats = copy.deepcopy(cc_stats)
+    wc_stats = copy.deepcopy(cc_stats)
+    ww_stats = copy.deepcopy(cc_stats)
+    for s in stat_list:
+        # Reformat stats
+        stats = {k: None for k in stat_list[s][0].keys()}
+        for k in stat_list[s][0]:
+            if k in ["epoch", "is_filtered", "label", "d_preds"]:
+                stats[k] = [i[k] for i in stat_list[s]]
+            elif k in ["d_losses", "d_probs"]:
+                stats[k] = [["%.4f" % (j) for j in i[k]] for i in stat_list[s]]
+            else:
+                stats[k] = ["%.4f" % (i[k]) for i in stat_list[s]]
+        stats['filtered_epochs'] = [i_itr for i_itr, i in enumerate(stat_list[s]) if i['is_filtered']]
+
+        # Collect global stats
+        for ep in stats['epoch']:
+            majority_pred = max(set(stats['d_preds'][ep]), key=stats['d_preds'][ep].count)
+            majority_pred_count = stats['d_preds'][ep].count(majority_pred)
+            # In case multiple classes have the majority count, take the class of the first prediction
+            if stats['d_preds'][ep].count(stats['d_preds'][ep][0]) == majority_pred_count:
+                majority_pred = stats['d_preds'][ep][0]
+
+            stat_dict = None
+            if s not in corrupted_ids and majority_pred == stats['label'][0]:
+                stat_dict = cc_stats
+            elif s not in corrupted_ids and majority_pred != stats['label'][0]:
+                stat_dict = cw_stats
+            elif s in corrupted_ids and majority_pred != stats['label'][0]:
+                stat_dict = wc_stats
+            elif s in corrupted_ids and majority_pred == stats['label'][0]:
+                stat_dict = ww_stats
+            stat_dict[ep]['cnt'] += 1
+            stat_dict[ep]['filtered'] += 1 if ep in stats['filtered_epochs'] else 0
+
+    logger.info("")
+    logger.info("Filtered samples (loss weight < 0.5) per epoch by label-prediction agreement:")
+    for e in range(args.num_train_epochs):
+        logger.info("%d "
+                    "cc: %d of %d, cw: %d of %d, "
+                    "wc: %d of %d, ww: %d of %d" % (e,
+                                                    cc_stats[e]['filtered'], cc_stats[e]['cnt'],
+                                                    cw_stats[e]['filtered'], cw_stats[e]['cnt'],
+                                                    wc_stats[e]['filtered'], wc_stats[e]['cnt'],
+                                                    ww_stats[e]['filtered'], ww_stats[e]['cnt']))
+
+    if tb_writer:
+        tb_writer.close()
+
+    if args.output_dir is not None and args.save_checkpoints:
+        logger.info("Saving model checkpoint to %s", args.output_dir)
+        model.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
+    torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
+
+    if args.output_dir is not None and args.save_stats:
+        with gzip.open(os.path.join(args.output_dir, "stat_list.json.gz"), "w") as f:
+            f.write(json.dumps(stat_list, indent=2).encode('utf-8'))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/storm_dst.py b/storm_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af0d3b6e3b859aaec55b751a71c94d28c2fb71e
--- /dev/null
+++ b/storm_dst.py
@@ -0,0 +1,917 @@
+# coding=utf-8
+#
+# Copyright 2020-2024 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import autograd_hacks
+import copy
+import glob
+import gzip
+import higher
+import json
+import logging
+import math
+import os
+import pickle
+import random
+import re
+import sys
+import torch
+import transformers
+import warnings
+
+import numpy as np
+
+from accelerate import (Accelerator)
+from tensorboardX import (SummaryWriter)
+from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler)
+from tqdm import (tqdm, trange)
+from transformers import (
+    WEIGHTS_NAME,
+    BertConfig, BertModel, BertTokenizer,
+    RobertaConfig, RobertaModel, RobertaTokenizer,
+    ElectraConfig, ElectraModel, ElectraTokenizer)
+
+from dst.agra_dst import AGRA
+from dst.modeling_dst import (TransformerForDST)
+from dst.utils_storm_dst import (Results, Filter, batch_to_dict)
+from utils_gpu import (from_device, to_device)
+from utils_storm import (print_header)
+
+trippy_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'trippy'))
+sys.path.insert(0, trippy_path)
+
+from trippy.data_processors import (PROCESSORS)
+from trippy.run_dst import (eval_metric, predict_and_format)
+from trippy.utils_dst import (convert_examples_to_features)
+from trippy.tensorlistdataset import (TensorListDataset)
+
+warnings.filterwarnings("ignore", "Detected call of `lr_scheduler\.step\(\)` before `optimizer\.step\(\)`\.", UserWarning)
+
+MODEL_CLASSES = {
+    'bert': (BertConfig, TransformerForDST('bert'), BertModel, BertTokenizer),
+    'roberta': (RobertaConfig, TransformerForDST('roberta'), RobertaModel, RobertaTokenizer),
+    'electra': (ElectraConfig, TransformerForDST('electra'), ElectraModel, ElectraTokenizer),
+}
+
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    def list_of_strings(arg):
+        return arg.split(',')
+    
+    parser = argparse.ArgumentParser(description="STORM for dialogue state tracking with TripPy.")
+
+    # Required parameters
+    parser.add_argument("--task_name", type=str, default=None, required=True,
+                        help="Name of the task (e.g., multiwoz21).")
+    parser.add_argument("--data_dir", type=str, default=None, required=True,
+                        help="Task database.")
+    parser.add_argument("--dataset_config", type=str, default=None, required=True,
+                        help="Dataset configuration file.")
+    parser.add_argument("--predict_type", type=str, default=None, required=True,
+                        help="Portion of the data to perform prediction on (e.g., dev, test).")
+    parser.add_argument("--model_type", type=str, default=None, required=True,
+                        help="Model type.",
+                        choices=list(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path",type=str, default=None, required=True,
+                        help="Path to pretrained model or model identifier from huggingface.co/models.")
+    parser.add_argument("--output_dir", type=str, default=None, required=True,
+                        help="The output directory where the model checkpoints and predictions will be written.")
+
+    # Other parameters
+    parser.add_argument("--max_seq_length", type=int, default=180,
+                        help="Maximum input length after tokenization. "
+                             "Longer sequences will be truncated, shorter ones padded.")
+    parser.add_argument("--do_train", action='store_true',
+                        help="Whether to run training.")
+    parser.add_argument("--do_eval", action='store_true',
+                        help="Whether to run eval on the <predict_type> set.")
+    parser.add_argument("--evaluate_during_training", action='store_true',
+                        help="Run evaluation during training at each logging step.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+
+    parser.add_argument("--dropout_rate", type=float, default=0.3,
+                        help="Dropout rate for transformer encoder representations.")
+    parser.add_argument("--heads_dropout", type=float, default=0.0,
+                        help="Dropout rate for classification heads.")
+    parser.add_argument("--class_loss_ratio", type=float, default=0.8,
+                        help="The ratio applied on class loss in total loss calculation. "
+                             "Should be a value in [0.0, 1.0]. "
+                             "The ratio applied on token loss is (1-class_loss_ratio)/2. "
+                             "The ratio applied on refer loss is (1-class_loss_ratio)/2.")
+    parser.add_argument("--token_loss_for_nonpointable", action='store_true',
+                        help="Whether the token loss for classes other than copy_value contribute towards total loss.")
+    parser.add_argument("--refer_loss_for_nonpointable", action='store_true',
+                        help="Whether the refer loss for classes other than refer contribute towards total loss.")
+
+    parser.add_argument("--no_append_history", action='store_true',
+                        help="Whether or not to append the dialog history to each turn.")
+    parser.add_argument("--no_use_history_labels", action='store_true',
+                        help="Whether or not to label the history as well.")
+    parser.add_argument("--no_label_value_repetitions", action='store_true',
+                        help="Whether or not to label values that have been mentioned before.")
+    parser.add_argument("--swap_utterances", action='store_true',
+                        help="Whether or not to swap the turn utterances (default: usr|sys, swapped: sys|usr).")
+    parser.add_argument("--delexicalize_sys_utts", action='store_true',
+                        help="Whether or not to delexicalize the system utterances.")
+    parser.add_argument("--class_aux_feats_inform", action='store_true',
+                        help="Whether or not to use the identity of informed slots as auxiliary featurs for class prediction.")
+    parser.add_argument("--class_aux_feats_ds", action='store_true',
+                        help="Whether or not to use the identity of slots in the current dialog state as auxiliary featurs for class prediction.")
+
+    parser.add_argument("--train_batch_size", type=int, default=8,
+                        help="Batch size for the training dataloader.")
+    parser.add_argument("--eval_batch_size", type=int, default=8,
+                        help="Batch size for the evaluation dataloader.")
+    parser.add_argument("--learning_rate", type=float, default=5e-5,
+                        help="Initial learning rate (after the potential warmup period) to use.")
+    parser.add_argument("--rescaler_learning_rate", type=float, default=1e-2,
+                        help="Initial learning rate (after the potential warmup period) to use.")
+    parser.add_argument("--weight_decay", type=float, default=0.0,
+                        help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-6,
+                        help="Epsilon for Adam optimizer.")
+    parser.add_argument("--num_train_epochs", type=int, default=3,
+                        help="Total number of training epochs to perform.")
+    parser.add_argument("--max_train_steps", type=int, default=None,
+                        help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
+    parser.add_argument("--optimizer", type=str, default='AdamW',
+                        help="Optimizer to use.",
+                        choices=["Adam", "AdamW"])
+    parser.add_argument("--lr_scheduler_type", type=str, default="linear",
+                        help="The scheduler type to use.",
+                        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"])
+    parser.add_argument("--warmup_proportion", type=float, default=0.0,
+                        help="Proportion of steps for the warmup in the lr scheduler.")
+    parser.add_argument("--svd", type=float, default=0.0,
+                        help="Slot value dropout ratio.")
+
+    parser.add_argument('--seed', type=int, default=42,
+                        help="random seed for initialization")
+    parser.add_argument('--local_files_only', action='store_true',
+                        help="Whether to only load local model files (useful when working offline).")
+    parser.add_argument('--logging_steps', type=int, default=100,
+                        help="Log every X updates steps.")
+    parser.add_argument('--save_steps', type=int, default=0,
+                        help="Save checkpoint every X updates steps. Overwritten by --save_epochs.")
+    parser.add_argument('--save_epochs', type=int, default=0,
+                        help="Save checkpoint every X epochs. Overrides --save_steps.")
+    parser.add_argument("--save_stats", action='store_true',
+                        help="When set, saves detailed training statistics as gzipped json.")
+    parser.add_argument("--eval_all_checkpoints", action='store_true',
+                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
+    parser.add_argument('--overwrite_output_dir', action='store_true',
+                        help="Overwrite the content of the output directory")
+    parser.add_argument('--overwrite_cache', action='store_true',
+                        help="Overwrite the cached training and evaluation sets")
+
+    parser.add_argument('--stats_window', type=int, default=1,
+                        help="Window size for sample statistics memory (in number of batches).")
+    parser.add_argument("--simulate_only", action='store_true',
+                        help="When set, rescaling is not actually applied (but still learned).")
+    parser.add_argument('--dropout_rounds', type=int, default=2,
+                        help="Number of additional forward passes to compute sample statistics.")
+    parser.add_argument("--no_cat", action='store_true',
+                        help="When set, sample statistics are computed by label-prediction agreement.")
+    parser.add_argument("--agra", action='store_true',
+                        help="Use original AGRA method instead of rescaler. "
+                             "AGRA does not use meta learning.")
+    parser.add_argument("--agra_weighted", action='store_true',
+                        help="When set, uses class weighting to sample comparison batches.")
+    parser.add_argument("--agra_loss", type=str, default="CE",
+                        help="Comparison loss for AGRA.",
+                        choices=["CE", "F1"])
+    parser.add_argument("--meta_dataset", type=str, default="train",
+                        help="Dataset used for the meta update step (default for STORM: 'train').",
+                        choices=["train", "eval"])
+    parser.add_argument('--meta_innerloop_rounds', type=int, default=1,
+                        help="Number of inner loop traversals for meta learning.")
+    parser.add_argument("--rescaler_binary", action='store_true',
+                        help="When set, rescaler will produce binary loss weights, i.e., 0 or 1.")
+    parser.add_argument("--rescaler_binary_threshold", type=float, default=0.5,
+                        help="Confidence threshold for binary rescaler.",
+                        choices=[0.1, 0.2, 0.3, 0.4, 0.5])
+    parser.add_argument("--rescaler_feats", type=str, default="default",
+                        help="Feature set for rescaler.",
+                        choices=["default", "loss"])
+    parser.add_argument("--rescaler_feats_no_cat", action='store_true',
+                        help="When set, cat is not added to the feature set of the rescaler.")
+    parser.add_argument("--no_class_separation", action='store_true',
+                        help="When set, sample statistics are not separated by target classes.")
+    parser.add_argument("--no_meta_loss_rescaling", action='store_true',
+                        help="When set, no meta loss rescaling is used.")
+
+    args = parser.parse_args()
+
+    assert args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0
+    assert args.svd >= 0.0 and args.svd <= 1.0
+    assert not args.class_aux_feats_ds or args.eval_batch_size == 1
+    assert not args.class_aux_feats_inform or args.eval_batch_size == 1
+
+    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    
+    return args
+
+
+def train(args, train_dataset, eval_dataset, features, model, tokenizer, processor):
+    tb_writer = SummaryWriter()
+
+    train_sampler = RandomSampler(train_dataset)
+    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, drop_last=True)
+
+    if args.max_train_steps is not None:
+        t_total = args.max_train_steps
+        args.num_train_epochs = args.max_train_steps // len(train_dataloader) + 1
+    else:
+        t_total = len(train_dataloader) * args.num_train_epochs
+
+    if args.save_epochs > 0:
+        args.save_steps = t_total // args.num_train_epochs * args.save_epochs
+
+    num_warmup_steps = int(t_total * args.warmup_proportion)
+
+    # Optimizer
+    # Split weights in two groups, one with weight decay and the other not.
+    no_decay = ['bias', 'LayerNorm.weight']
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and not "rescaler" in n],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and not "rescaler" in n],
+            "weight_decay": 0.0,
+        },
+    ]
+    inner_optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and not "rescaler" in n],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and not "rescaler" in n],
+            "weight_decay": 0.0,
+        },
+    ]
+    outer_optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and "rescaler" in n],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and "rescaler" in n],
+            "weight_decay": 0.0,
+        },
+    ]
+    inner_optimizer = torch.optim.Adam(inner_optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+    if args.optimizer == "Adam":
+        optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+        meta_optimizer = torch.optim.Adam(outer_optimizer_grouped_parameters, lr=args.rescaler_learning_rate, eps=args.adam_epsilon)
+    else:
+        optimizer = transformers.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+        meta_optimizer = transformers.AdamW(outer_optimizer_grouped_parameters, lr=args.rescaler_learning_rate, eps=args.adam_epsilon)
+
+    # This would jointly regulate the lr for optimizer and inner_optimizer
+    # if both optimizers would use the same grouped_parameters, i.e.,
+    # the lr seems to be tied to the parameters, not to the optimizer.
+    scheduler = transformers.get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=inner_optimizer,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=t_total,
+    )
+
+    if args.agra:
+        agra = AGRA(args.agra_loss,
+                    model.class_labels,
+                    args.agra_weighted,
+                    train_dataset,
+                    model=model,
+                    device=args.device,
+                    window_size=args.stats_window)
+
+    # Train!
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Total train batch size = {args.train_batch_size}")
+    logger.info(f"  Total optimization steps = {t_total}")
+    logger.info(f"  Warmup steps = {num_warmup_steps}")
+
+    global_step = 0
+    tr_loss, logging_loss = 0.0, 0.0
+    model.zero_grad()
+    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=False)
+
+    stat_list = {}
+    filter_list_by_epoch = {-1: []}
+    for epoch_itr, _ in enumerate(train_iterator):
+        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False)
+
+        train_inner_sampler = RandomSampler(train_dataset)
+        train_inner_dataloader = DataLoader(
+            train_dataset, sampler=train_inner_sampler, batch_size=args.train_batch_size, drop_last=True)
+
+        # Meta dataloader
+        if args.meta_dataset == "eval":
+            meta_dataset = eval_dataset
+        else:
+            meta_dataset = train_dataset
+        meta_sampler = RandomSampler(meta_dataset)
+        meta_dataloader = DataLoader(
+            meta_dataset, sampler=meta_sampler, batch_size=args.train_batch_size, drop_last=True)
+
+        if args.agra:
+            agra.build_dataloader(args.train_batch_size)
+
+        model.train()
+
+        filter_list_by_epoch[epoch_itr] = []
+        # Setting this to 1 will cause Filter to not separate stats by target class
+        num_filter_labels = 1 if args.no_class_separation else model.class_labels
+        epoch_filter = Filter(num_filter_labels,
+                              args.stats_window,
+                              args.train_batch_size)
+        if not args.no_meta_loss_rescaling:
+            meta_epoch_filter = Filter(num_filter_labels,
+                                       args.stats_window,
+                                       args.train_batch_size)
+        for step, orig_batch in enumerate(epoch_iterator):
+            orig_batch = to_device(batch_to_dict(orig_batch), args.device)
+            batch_size = orig_batch['input_ids'].size(0)
+
+            meta_batch = next(iter(meta_dataloader))
+            meta_batch = to_device(batch_to_dict(meta_batch), args.device)
+            meta_optimizer.zero_grad()
+
+            total_inner_rescaler_outputs = {}
+            with higher.innerloop_ctx(model, inner_optimizer, copy_initial_weights=False, track_higher_grads=True) as (fmodel, diffopt):
+                for ir in range(args.meta_innerloop_rounds):
+                    if ir > 0:
+                        batch = next(iter(train_inner_dataloader))
+                        batch = to_device(batch_to_dict(batch), args.device)
+                    else:
+                        batch = orig_batch
+
+                    if args.agra:
+                        grad_scores = agra.agra_step(batch)
+
+                    outputs = fmodel(**batch)
+
+                    # These are the losses we use for backprop
+                    class_losses = outputs[2][0]
+                    token_losses = outputs[4][0]
+                    refer_losses = outputs[7][0]
+                    outputs = from_device(outputs)
+                    labels = from_device(batch['class_label_id'])
+                    step_results = Results(outputs[2], outputs[3], labels)
+
+                    if args.agra:
+                        step_results['grad_scores'] = grad_scores
+
+                    # Get batch loss/grad statistics, accumulated across samples
+                    epoch_filter.update_batch_stats(step_results)
+
+                    # Get stats
+                    kls, ovls = epoch_filter.get_sample_stats(batch,
+                                                              step_results,
+                                                              args.no_cat)
+
+                    # ---------------------------
+                    # Rescale sample losses
+                    # ---------------------------
+
+                    total_loss = None # Recompute total loss as in modeling_dst.py
+                    filter_ids = {}
+                    for slot in model.slot_list:
+                        if args.agra:
+                            # AGRA baseline
+                            inner_rescaler_outputs = (grad_scores[slot] >= 0).float().to(args.device)
+                        else:
+                            # Get weights for each sample for each slot gate
+                            inner_rescaler_outputs = call_rescaler(args, fmodel, slot,
+                                                                   epoch_filter, step_results,
+                                                                   kls, ovls,
+                                                                   class_losses, batch)
+
+                        # Rescale
+                        if not args.simulate_only:
+                            scaling_factor = class_losses[slot].sum()
+                            class_losses[slot] *= inner_rescaler_outputs
+                            scaling_factor /= class_losses[slot].sum()
+                            if not args.no_meta_loss_rescaling:
+                                class_losses[slot] *= scaling_factor
+                        if ir == 0:
+                            total_inner_rescaler_outputs[slot] = inner_rescaler_outputs
+
+                        if 'refer' in model.class_types:
+                            per_example_loss = \
+                                (args.class_loss_ratio) * class_losses[slot] + \
+                                ((1 - args.class_loss_ratio) / 2) * token_losses[slot] + \
+                                ((1 - args.class_loss_ratio) / 2) * refer_losses[slot]
+                        else:
+                            per_example_loss = \
+                                args.class_loss_ratio * class_losses[slot] + \
+                                (1 - args.class_loss_ratio) * token_losses[slot]
+
+                        if total_loss is None:
+                            total_loss = per_example_loss.sum()
+                        else:
+                            total_loss += per_example_loss.sum()
+
+                        # Get stats
+                        if ir == 0:
+                            filter_ids[slot] = batch['ids'][(inner_rescaler_outputs < 0.5).cpu()].tolist()
+
+                    # Get stats
+                    if ir == 0:
+                        filter_list_by_epoch[epoch_itr] += [x for xs in list(filter_ids.values()) for x in xs]
+
+                        for l_itr in range(batch_size):
+                            batch_id = batch['ids'][l_itr].item()
+
+                            # Collect statistics
+                            if batch_id not in stat_list:
+                                stat_list[batch_id] = []
+                            stat_list[batch_id].append(
+                                {"is_filtered": {s: batch_id in filter_ids[s] for s in model.slot_list},
+                                 "epoch": epoch_itr,
+                                 "label": {s: features[batch_id].class_label_id[s] for s in model.slot_list},
+                                 "weight": {s: total_inner_rescaler_outputs[s][l_itr].item() for s in model.slot_list},
+                                 "guid": features[batch_id].guid,
+                                 "pred": {s: step_results["preds"][s][0][l_itr].item() for s in model.slot_list},
+                                 "probs": {s: step_results["probs"][s][0][l_itr].tolist() for s in model.slot_list}})
+
+                    diffopt.step(total_loss) # Perform a training step
+
+                    if ir > 0:
+                        batch = from_device(batch)
+
+                # End of inner loop(s)
+
+                # Meta update
+                if not args.agra:
+                    meta_outputs = fmodel(**meta_batch)
+                    meta_loss = meta_outputs[0][0]
+                    meta_class_losses = meta_outputs[2][0]
+                    meta_token_losses = meta_outputs[4][0]
+                    meta_refer_losses = meta_outputs[7][0]
+                    meta_outputs = from_device(meta_outputs)
+                    meta_labels = from_device(meta_batch['class_label_id'])
+
+                    if not args.no_meta_loss_rescaling:
+                        meta_step_results = Results(meta_outputs[2], meta_outputs[3], meta_labels)
+                        meta_epoch_filter.update_batch_stats(meta_step_results)
+                        meta_kls, meta_ovls = meta_epoch_filter.get_sample_stats(meta_batch,
+                                                                                 meta_step_results,
+                                                                                 args.no_cat)
+
+                        meta_total_loss = None # Recompute total loss as in modeling_dst.py
+                        for slot in model.slot_list:
+                            rescaler_outputs = call_rescaler(args, fmodel, slot,
+                                                             meta_epoch_filter, meta_step_results,
+                                                             meta_kls, meta_ovls,
+                                                             meta_class_losses, meta_batch)
+                            if not args.simulate_only:
+                                meta_scaling_factor = meta_class_losses[slot].sum()
+                                meta_class_losses[slot] *= rescaler_outputs
+                                meta_scaling_factor /= meta_class_losses[slot].sum()
+                                meta_class_losses[slot] *= meta_scaling_factor
+
+                            if 'refer' in model.class_types:
+                                meta_per_example_loss = \
+                                    (args.class_loss_ratio) * meta_class_losses[slot] + \
+                                    ((1 - args.class_loss_ratio) / 2) * meta_token_losses[slot] + \
+                                    ((1 - args.class_loss_ratio) / 2) * meta_refer_losses[slot]
+                            else:
+                                meta_per_example_loss = \
+                                    args.class_loss_ratio * meta_class_losses[slot] + \
+                                    (1 - args.class_loss_ratio) * meta_token_losses[slot]
+                            if meta_total_loss is None:
+                                meta_total_loss = meta_per_example_loss.sum()
+                            else:
+                                meta_total_loss += meta_per_example_loss.sum()
+                        meta_loss = meta_total_loss
+
+                    meta_loss.backward()
+
+            # Meta update
+            if not args.agra:
+                meta_optimizer.step()
+                meta_optimizer.zero_grad()
+
+            # Update model parameters
+            model.zero_grad()
+            outputs = model(**orig_batch)
+            class_losses = outputs[2][0]
+            token_losses = outputs[4][0]
+            refer_losses = outputs[7][0]
+            outputs = from_device(outputs)
+            total_loss = None # Recompute total loss as in modeling_dst.py
+            for slot in model.slot_list:
+                if not args.simulate_only:
+                    class_losses[slot] *= total_inner_rescaler_outputs[slot].clone().detach()
+                if 'refer' in model.class_types:
+                    per_example_loss = \
+                        (args.class_loss_ratio) * class_losses[slot] + \
+                        ((1 - args.class_loss_ratio) / 2) * token_losses[slot] + \
+                        ((1 - args.class_loss_ratio) / 2) * refer_losses[slot]
+                else:
+                    per_example_loss = \
+                        args.class_loss_ratio * class_losses[slot] + \
+                        (1 - args.class_loss_ratio) * token_losses[slot]
+                if total_loss is None:
+                    total_loss = per_example_loss.sum()
+                else:
+                    total_loss += per_example_loss.sum()
+            tr_loss += total_loss.item()
+            total_loss.backward()
+            for o, oi in zip(optimizer.param_groups, inner_optimizer.param_groups):
+                o['lr'] = oi['lr']
+            optimizer.step()
+            optimizer.zero_grad()
+            if args.agra:
+                autograd_hacks.clear_grad1(model)
+
+            orig_batch = from_device(orig_batch)
+
+            scheduler.step()
+            global_step += 1
+            
+            # Log metrics
+            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
+                tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step)
+                tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
+                logging_loss = tr_loss
+
+            # Save model checkpoint
+            if args.save_steps > 0 and global_step % args.save_steps == 0:
+                output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
+                logger.info("Saving model checkpoint to %s", output_dir)
+                if not os.path.exists(output_dir):
+                    os.makedirs(output_dir)
+                model.save_pretrained(output_dir)
+
+            if args.max_train_steps is not None and global_step > args.max_train_steps:
+                epoch_iterator.close()
+                break
+
+        logger.info("")
+        logger.info("Filtered (loss weight < 0.5) samples in epoch %s: %s of %s" % (epoch_itr,
+                                                                                    len(filter_list_by_epoch[epoch_itr]),
+                                                                                    len(features) * len(model.slot_list)))
+        tb_writer.add_scalars('filtered_samples', {'Filtered samples': len(filter_list_by_epoch[epoch_itr]),
+                                                   'Total': len(features) * len(model.slot_list)}, epoch_itr)
+
+        # Evaluate task performance
+        if args.evaluate_during_training:
+            results = evaluate(args, model, tokenizer, processor, prefix=global_step)
+            logger.info(f"epoch {epoch_itr}: {results}")
+            for key, value in results.items():
+                tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
+
+        if args.max_train_steps is not None and global_step > args.max_train_steps:
+            train_iterator.close()
+            break
+
+    tb_writer.close()
+
+    if args.output_dir is not None and args.save_stats:
+        with gzip.open(os.path.join(args.output_dir, "stat_list.%s.json.gz" % (args.predict_type)), "w") as f:
+            f.write(json.dumps(stat_list, indent=2).encode('utf-8'))
+
+    return global_step, tr_loss / global_step
+
+
+def evaluate(args, model, tokenizer, processor, prefix=""):
+    dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=True)
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    args.eval_batch_size = args.eval_batch_size
+    eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly
+    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+    # Eval!
+    logger.info("***** Running evaluation {} *****".format(prefix))
+    logger.info("  Num examples = %d", len(dataset))
+    logger.info("  Batch size = %d", args.eval_batch_size)
+    all_results = []
+    all_preds = []
+    ds = {slot: 'none' for slot in model.slot_list}
+    with torch.no_grad():
+        diag_state = {slot: torch.tensor([0 for _ in range(args.eval_batch_size)]).to(args.device) for slot in model.slot_list}
+    for batch in tqdm(eval_dataloader, desc="Evaluating"):
+        model.eval()
+        batch = to_device(batch, args.device)
+
+        # Reset dialog state if turn is first in the dialog.
+        turn_itrs = [features[i.item()].guid.split('-')[-1] for i in batch[9]]
+        reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
+        for slot in model.slot_list:
+            for i in reset_diag_state:
+                diag_state[slot][i] = 0
+
+        with torch.no_grad():
+            inputs = {'input_ids':       batch[0],
+                      'input_mask':      batch[1],
+                      'segment_ids':     batch[2],
+                      'start_pos':       batch[3],
+                      'end_pos':         batch[4],
+                      'inform_slot_id':  batch[5],
+                      'refer_id':        batch[6],
+                      'diag_state':      diag_state,
+                      'class_label_id':  batch[8],
+                      'suppress_dropout_passes': True}
+            unique_ids = [features[i.item()].guid for i in batch[9]]
+            values = [features[i.item()].values for i in batch[9]]
+            input_ids_unmasked = [features[i.item()].input_ids_unmasked for i in batch[9]]
+            inform = [features[i.item()].inform for i in batch[9]]
+            outputs = model(**inputs)
+
+            # Update dialog state for next turn.
+            for slot in model.slot_list:
+                updates = outputs[3][0][slot].max(1)[1]
+                for i, u in enumerate(updates):
+                    if u != 0:
+                        diag_state[slot][i] = u
+
+        results = eval_metric(model, inputs, outputs[0][0], outputs[1][0], outputs[3][0], outputs[5][0], outputs[6][0], outputs[8][0])
+        preds, ds = predict_and_format(model, tokenizer, inputs, outputs[3][0], outputs[5][0], outputs[6][0], outputs[8][0], unique_ids, input_ids_unmasked, values, inform, prefix, ds)
+        all_results.append(results)
+        all_preds.append(preds)
+
+    all_preds = [item for sublist in all_preds for item in sublist] # Flatten list
+
+    # Generate final results
+    final_results = {}
+    for k in all_results[0].keys():
+        final_results[k] = torch.stack([r[k] for r in all_results]).mean()
+
+    # Write final predictions (for evaluation with external tool)
+    output_prediction_file = os.path.join(args.output_dir, "pred_res.%s.%s.json" % (args.predict_type, prefix))
+    with open(output_prediction_file, "w") as f:
+        json.dump(all_preds, f, indent=2)
+
+    return final_results
+
+
+def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False):
+    # Load data features from cache or dataset file
+    cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_{}_features'.format(
+        args.predict_type if evaluate else 'train'))
+    if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples:
+        logger.info("Loading features from cached file %s", cached_file)
+        features = torch.load(cached_file)
+    else:
+        if args.task_name == "unified":
+            logger.info("Creating features from unified data format")
+        else:
+            logger.info("Creating features from dataset file at %s", args.data_dir)
+        processor_args = {'no_append_history': args.no_append_history,
+                          'no_use_history_labels': args.no_use_history_labels,
+                          'no_label_value_repetitions': args.no_label_value_repetitions,
+                          'swap_utterances': args.swap_utterances,
+                          'delexicalize_sys_utts': args.delexicalize_sys_utts,
+                          'unk_token': '<unk>' if args.model_type == 'roberta' else '[UNK]'}
+        if evaluate and args.predict_type == "dev":
+            examples = processor.get_dev_examples(args.data_dir, processor_args)
+        elif evaluate and args.predict_type == "test":
+            examples = processor.get_test_examples(args.data_dir, processor_args)
+        else:
+            examples = processor.get_train_examples(args.data_dir, processor_args)
+        features = convert_examples_to_features(examples=examples,
+                                                slot_list=model.slot_list,
+                                                class_types=model.class_types,
+                                                model_type=args.model_type,
+                                                tokenizer=tokenizer,
+                                                max_seq_length=args.max_seq_length,
+                                                slot_value_dropout=(0.0 if evaluate else args.svd))
+        logger.info("Saving features into cached file %s", cached_file)
+        torch.save(features, cached_file)
+
+    # Convert to Tensors and build dataset
+    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
+    f_start_pos = [f.start_pos for f in features]
+    f_end_pos = [f.end_pos for f in features]
+    f_inform_slot_ids = [f.inform_slot for f in features]
+    f_refer_ids = [f.refer_id for f in features]
+    f_diag_state = [f.diag_state for f in features]
+    f_class_label_ids = [f.class_label_id for f in features]
+    all_start_positions = {}
+    all_end_positions = {}
+    all_inform_slot_ids = {}
+    all_refer_ids = {}
+    all_diag_state = {}
+    all_class_label_ids = {}
+    for s in model.slot_list:
+        all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], dtype=torch.long)
+        all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], dtype=torch.long)
+        all_inform_slot_ids[s] = torch.tensor([f[s] for f in f_inform_slot_ids], dtype=torch.long)
+        all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], dtype=torch.long)
+        all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], dtype=torch.long)
+        all_class_label_ids[s] = torch.tensor([f[s] for f in f_class_label_ids], dtype=torch.long)
+    dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                all_start_positions, all_end_positions,
+                                all_inform_slot_ids,
+                                all_refer_ids,
+                                all_diag_state,
+                                all_class_label_ids, all_example_index)
+
+    return dataset, features
+
+
+def call_rescaler(args, model, slot, epoch_filter, step_results, kls, ovls, class_losses, batch):
+    d_cat = int(not args.no_cat) * (2 - step_results['agreement'][slot].type(torch.int)) # 0, 1 or 2 (agree or disagree)
+
+    rlbls = torch.full(batch['class_label_id'][slot].size(), 0) if args.no_class_separation else from_device(batch['class_label_id'][slot])
+    if args.rescaler_feats == "default":
+        feats = torch.cat((step_results['losses_means'][slot].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('loss_means_mean', slot)[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('loss_means_std', slot)[rlbls, d_cat].reshape(-1,1),
+                           step_results['losses_stds'][slot].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('loss_stds_mean', slot)[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('loss_stds_std', slot)[rlbls, d_cat].reshape(-1,1),
+                           step_results['probs_means'][slot].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('prob_means_mean', slot)[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('prob_means_std', slot)[rlbls, d_cat].reshape(-1,1),
+                           step_results['probs_stds'][slot].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('prob_stds_mean', slot)[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('prob_stds_std', slot)[rlbls, d_cat].reshape(-1,1),
+                           torch.tensor(kls[slot]).reshape(-1,1),
+                           epoch_filter.get_stats_tensor('kl_mean', slot)[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('kl_std', slot)[rlbls, d_cat].reshape(-1,1),
+                           torch.tensor(ovls[slot]).reshape(-1,1),
+                           epoch_filter.get_stats_tensor('ovl_mean', slot)[rlbls, d_cat].reshape(-1,1),
+                           epoch_filter.get_stats_tensor('ovl_std', slot)[rlbls, d_cat].reshape(-1,1)),
+                          dim=1).type(torch.float).to(args.device)
+    elif args.rescaler_feats == "loss":
+        feats = class_losses[slot].reshape(-1,1).detach().clone() # Loss w/o grad
+    else:
+        raise Exception("Unknown meta filter features.")
+    if not args.rescaler_feats_no_cat:
+        feats = torch.cat((feats, step_results['agreement'][slot].unsqueeze(1).type(torch.float).to(args.device)), dim=1)
+    rescaler_outputs = model(feats=feats, mode="rescaler")
+    if not args.no_class_separation:
+        rescaler_outputs = rescaler_outputs[0]
+        new_rescaler_outputs = rescaler_outputs[0] * (step_results['labels'][slot] == 0).unsqueeze(1).to(args.device)
+        for c in range(1, len(model.class_types)):
+            new_rescaler_outputs += rescaler_outputs[c] * (step_results['labels'][slot] == c).unsqueeze(1).to(args.device)
+        rescaler_outputs = (new_rescaler_outputs,)
+    rescaler_outputs = rescaler_outputs[0][:,1]
+    return rescaler_outputs
+
+
+def main():
+    args = parse_args()
+
+    task_name = args.task_name.lower()
+    if task_name not in PROCESSORS:
+        raise ValueError("Task not found: %s" % (task_name))
+
+    processor = PROCESSORS[task_name](args.dataset_config)
+    dst_slot_list = processor.slot_list
+    dst_class_types = processor.class_types
+    dst_class_labels = len(dst_class_types)
+
+    # Setup logging, print header
+    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
+                        datefmt = '%m/%d/%Y %H:%M:%S',
+                        level = logging.INFO)
+    logger.setLevel(logging.INFO)
+    transformers.utils.logging.set_verbosity_info()
+
+    print_header()
+
+    for a in vars(args):
+        logger.info('  {:40s} {:s}'.format(a, str(getattr(args, a))))
+    logger.info("")
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        transformers.set_seed(args.seed)
+
+    # Handle the repository creation
+    if args.output_dir is not None:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    args.model_type = args.model_type.lower()
+    config_class, model_class, default_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+
+    # Load pretrained model and tokenizer
+    config = config_class.from_pretrained(args.model_name_or_path, local_files_only=args.local_files_only)
+
+    # Add DST specific parameters to config
+    config.dst_max_seq_length = args.max_seq_length
+    config.dst_dropout_rate = args.dropout_rate
+    config.dst_heads_dropout_rate = args.heads_dropout
+    config.dst_class_loss_ratio = args.class_loss_ratio
+    config.dst_token_loss_for_nonpointable = args.token_loss_for_nonpointable
+    config.dst_refer_loss_for_nonpointable = args.refer_loss_for_nonpointable
+    config.dst_class_aux_feats_inform = args.class_aux_feats_inform
+    config.dst_class_aux_feats_ds = args.class_aux_feats_ds
+    config.dst_slot_list = dst_slot_list
+    config.dst_class_types = dst_class_types
+    config.dst_class_labels = dst_class_labels
+
+    # Add STORM specific parameters to config
+    config.dropout_rounds = args.dropout_rounds
+    config.no_class_separation = args.no_class_separation
+    config.rescaler_binary = args.rescaler_binary
+    config.rescaler_binary_threshold = args.rescaler_binary_threshold
+    if args.rescaler_feats == "default":
+        config.rescaler_featnum = 18
+    elif args.rescaler_feats == "loss":
+        config.rescaler_featnum = 1
+    else:
+        raise Exception("Unknown rescaler features.")
+    if not args.rescaler_feats_no_cat:
+        config.rescaler_featnum += 1
+
+    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path,
+                                                do_lower_case=args.do_lower_case,
+                                                local_files_only=args.local_files_only)
+    model = model_class.from_pretrained(args.model_name_or_path,
+                                        from_tf=bool('.ckpt' in args.model_name_or_path),
+                                        config=config,
+                                        local_files_only=args.local_files_only)
+
+    logger.info("Updated model config: %s" % config)
+
+    model.to(args.device)
+
+    logger.info("Training/evaluation parameters %s", args)
+
+    # Training
+    if args.do_train:
+        train_dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=False)
+        eval_dataset, _ = load_and_cache_examples(args, model, tokenizer, processor, evaluate=True)
+        global_step, tr_loss = train(args, train_dataset, eval_dataset, features, model, tokenizer, processor)
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+
+        if args.output_dir is not None and args.save_steps > 0:
+            logger.info("Saving model checkpoint to %s", args.output_dir)
+            model.save_pretrained(args.output_dir)
+            tokenizer.save_pretrained(args.output_dir)
+        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
+
+    # Evaluation
+    results = []
+    if args.do_eval:
+        checkpoints = [args.output_dir]
+        if args.eval_all_checkpoints:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
+
+        logger.info("Evaluate the following checkpoints: %s", checkpoints)
+
+        for cItr, checkpoint in enumerate(checkpoints):
+            # Reload the model
+            global_step = checkpoint.split('-')[-1]
+            if cItr == len(checkpoints) - 1:
+                global_step = "final"
+            model = model_class.from_pretrained(checkpoint)
+            model.to(args.device)
+
+            # Evaluate
+            result = evaluate(args, model, tokenizer, processor, prefix=global_step)
+            result_dict = {k: float(v) for k, v in result.items()}
+            result_dict["global_step"] = global_step
+            results.append(result_dict)
+
+            for key in sorted(result_dict.keys()):
+                logger.info("%s = %s", key, str(result_dict[key]))
+
+        output_eval_file = os.path.join(args.output_dir, "eval_res.%s.json.gz" % (args.predict_type))
+        with gzip.open(output_eval_file, "w") as f:
+            f.write(json.dumps(results, indent=2).encode('utf-8'))
+
+    return results
+
+
+if __name__ == "__main__":
+    main()
diff --git a/utils_gpu.py b/utils_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aad860da56bc5ab63dc6e03086ffde3fb8ce40c
--- /dev/null
+++ b/utils_gpu.py
@@ -0,0 +1,56 @@
+# coding=utf-8
+#
+# Copyright 2024 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of AGRA
+# (arXiv:2306.04502)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.tokenization_utils_base import (BatchEncoding)
+
+
+def from_device(batch):
+    if isinstance(batch, tuple):
+        batch_on_cpu = tuple([from_device(element) for element in batch])
+    elif isinstance(batch, list):
+        batch_on_cpu = [from_device(element) for element in batch]
+    elif isinstance(batch, dict):
+        batch_on_cpu = {k: from_device(v) for k, v in batch.items()}
+    elif isinstance(batch, BatchEncoding):
+        batch_on_cpu = {k: from_device(v) for k, v in batch.items()}
+    else:
+        batch_on_cpu = batch.detach().cpu() if batch is not None else batch
+    return batch_on_cpu
+
+
+def to_device(batch, device):
+    if isinstance(batch, list):
+        return to_device_list(batch, device)
+    batch_on_device = {}
+    for element in batch:
+        if isinstance(batch[element], dict):
+            batch_on_device[element] = {k: v.to(device) for k, v in batch[element].items()}
+        else:
+            batch_on_device[element] = batch[element].to(device)
+    return batch_on_device
+
+
+def to_device_list(batch, device):
+    batch_on_device = []
+    for element in batch:
+        if isinstance(element, dict):
+            batch_on_device.append({k: v.to(device) for k, v in element.items()})
+        else:
+            batch_on_device.append(element.to(device))
+    return batch_on_device
diff --git a/utils_storm.py b/utils_storm.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a7c14344627e9df2ed10d30b2ec1e73ca96a938
--- /dev/null
+++ b/utils_storm.py
@@ -0,0 +1,295 @@
+# coding=utf-8
+#
+# Copyright 2024 Heinrich Heine University Duesseldorf
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import torch
+
+import numpy as np
+
+from statistics import (NormalDist)
+
+logger = logging.getLogger(__name__)
+
+
+def print_header():
+    logger.info(" ________  _________  ________  ________  _____ ______       ")
+    logger.info("|\   ____\|\___   ___\\\   __  \|\   __  \|\   _ \  _   \     ")
+    logger.info("\ \  \___|\|___ \  \_\ \  \|\  \ \  \|\  \ \  \\\\\__\ \  \    ")
+    logger.info(" \ \_____  \   \ \  \ \ \  \\\\\  \ \   _  _\ \  \\\|__| \  \   ")
+    logger.info("  \|____|\  \   \ \  \ \ \  \\\\\  \ \  \\\  \\\ \  \    \ \  \  ")
+    logger.info("    ____\_\  \   \ \__\ \ \_______\ \__\\\ _\\\ \__\    \ \__\ ")
+    logger.info("   |\_________\   \|__|  \|_______|\|__|\|__|\|__|     \|__| ")
+    logger.info("   \|_________|    (c) 2024 Heinrich Heine University        ")
+    logger.info("")
+
+
+def gaussian_KL(P_mu, P_std, Q_mu, Q_std):
+    if P_std == 0 or Q_std == 0:
+        return 0.0
+    x = np.log(Q_std / P_std)
+    x += ( pow(P_std, 2) + pow(P_mu - Q_mu, 2) ) / ( 2 * pow(Q_std, 2))
+    x -= 0.5
+    return x
+
+
+class Results:
+    def __init__(self, losses=None, logits=None, labels=None):
+        self.examples = {}
+        if losses is not None and logits is not None and labels is not None:
+            self.update(losses, logits, labels)
+
+
+    def __getitem__(self, idx):
+        return self.examples[idx]
+
+
+    def __setitem__(self, idx, item):
+        self.examples[idx] = item
+
+
+    def __repr__(self):
+        result = ""
+        for e in self.examples:
+            result += "%s: %s\n" % (e, self.examples[e])
+        return result
+
+
+    def update(self, losses, logits, labels):
+        self.examples['losses'] = losses
+        self.examples['logits'] = logits
+        self.examples['labels'] = labels
+        self.examples['probs'] = torch.softmax(self.examples['logits'], dim=2)
+        self.examples['preds'] = torch.argmax(self.examples['logits'], dim=2)
+
+        self._update_agreement(labels)
+        self._update_means()
+
+
+    def _update_agreement(self, labels):
+        agreement = self.examples['preds'][0] == labels # [0] is prediction for which we backprop
+        dropout_agreement_cnt = (self.examples['preds'] == labels).sum(0)
+        dropout_agreement = dropout_agreement_cnt >= self.examples['preds'].size(0) / 2
+        tie = dropout_agreement_cnt == self.examples['preds'].size(0) / 2
+        tie_idx = tie.nonzero(as_tuple=True)[0]
+        dropout_agreement[tie_idx] = agreement[tie_idx]
+        self.examples['agreement'] = dropout_agreement
+
+
+    def _update_means(self):
+        self.examples['losses_means'] = self.examples['losses'].mean(0)
+        self.examples['losses_stds'] = self.examples['losses'].std(0).nan_to_num()
+        self.examples['probs_means'] = self.examples['probs'].max(2)[0].mean(0)
+        self.examples['probs_stds'] = self.examples['probs'].max(2)[0].std(0).nan_to_num()
+
+        
+class Filter():
+    def __init__(self, num_labels, window_size=1, batch_size=1):
+        self.num_labels = num_labels # If num_labels > 1, stats will be kept separate per target class.
+        self.cats = 3
+        self.stats = {}
+        self.eps = 1e-8
+        self.window_size = window_size # window_size=0 -> unlimited window size
+        self.batch_size = batch_size
+
+
+    def get_stats(self, name):
+        return self.stats[name]
+
+
+    def get_stats_tensor(self, name):
+        return torch.tensor(list(self.stats[name].values()))
+
+
+    def __len__(self):
+        return self.cats
+
+
+    def _append_new_stats(self, stats):
+        for e in stats.examples:
+            if stats[e].dim() == 1:
+                if e not in self.stats:
+                    self.stats[e] = {}
+                for l in range(self.num_labels):
+                    if l not in self.stats[e] or self.window_size == 0:
+                        self.stats[e][l] = torch.tensor([], dtype=stats[e].dtype)
+                    if self.num_labels > 1:
+                        self.stats[e][l] = torch.cat((self.stats[e][l], stats[e][stats['labels'] == l]))
+                    else:
+                        self.stats[e][l] = torch.cat((self.stats[e][l], stats[e]))
+                    if self.window_size > 0:
+                        self.stats[e][l] = self.stats[e][l][-1 * self.window_size * self.batch_size:] # sliding window
+
+
+    def update_batch_stats(self, stats):
+        self._append_new_stats(stats)
+
+        self.stats['loss_means_mean'] = {}
+        self.stats['loss_means_std'] = {}
+        self.stats['loss_stds_mean'] = {}
+        self.stats['loss_stds_std'] = {}
+        self.stats['prob_means_mean'] = {}
+        self.stats['prob_means_std'] = {}
+        self.stats['prob_stds_mean'] = {}
+        self.stats['prob_stds_std'] = {}
+        self.stats['kl'] = {}
+        self.stats['kl_mean'] = {}
+        self.stats['kl_std'] = {}
+        self.stats['ovl'] = {}
+        self.stats['ovl_mean'] = {}
+        self.stats['ovl_std'] = {}
+        if 'grad_scores' in stats.examples:
+            self.stats['grad_scores_mean'] = {}
+            self.stats['grad_scores_std'] = {}
+        for l in range(self.num_labels):
+            # Get batch loss statistics, separate by sample, then accumulated
+            (self.stats['loss_means_mean'][l],
+             self.stats['loss_means_std'][l],
+             self.stats['loss_stds_mean'][l],
+             self.stats['loss_stds_std'][l]) = self._update_batch_stats(self.stats['losses_means'][l],
+                                                                        self.stats['losses_stds'][l],
+                                                                        self.stats['agreement'][l])
+
+            # Get batch probability statistics, separate by sample, then accumulated
+            (self.stats['prob_means_mean'][l],
+             self.stats['prob_means_std'][l],
+             self.stats['prob_stds_mean'][l],
+             self.stats['prob_stds_std'][l]) = self._update_batch_stats(self.stats['probs_means'][l],
+                                                                        self.stats['probs_stds'][l],
+                                                                        self.stats['agreement'][l])
+
+            # Get batch loss distribution KL divergence statistics
+            (self.stats['kl'][l],
+             self.stats['kl_mean'][l],
+             self.stats['kl_std'][l]) = self._update_batch_kl(self.stats['losses_means'][l],
+                                                              self.stats['losses_stds'][l],
+                                                              self.stats['loss_means_mean'][l],
+                                                              self.stats['loss_stds_mean'][l],
+                                                              self.stats['agreement'][l],
+                                                              mode="kl")
+
+            # Get batch loss distribution overlap statistics
+            (self.stats['ovl'][l],
+             self.stats['ovl_mean'][l],
+             self.stats['ovl_std'][l]) = self._update_batch_kl(self.stats['losses_means'][l],
+                                                               self.stats['losses_stds'][l],
+                                                               self.stats['loss_means_mean'][l],
+                                                               self.stats['loss_stds_mean'][l],
+                                                               self.stats['agreement'][l],
+                                                               mode="ovl")
+
+            # Get grad statistics
+            if 'grad_scores' in stats.examples:
+                (self.stats['grad_scores_mean'][l],
+                 self.stats['grad_scores_std'][l]) = self._update_global_batch_stats(self.stats['grad_scores'][l],
+                                                                                     self.stats['agreement'][l])
+
+
+    def _update_batch_stats(self, stats_means, stats_stds, agreement):
+        stats_means_mean = [None] * self.cats
+        stats_means_std = [None] * self.cats
+        stats_stds_mean = [None] * self.cats
+        stats_stds_std = [None] * self.cats
+        for cat in range(self.cats):
+            cat_means = self._get_stats_by_agreement_cat(stats_means, cat, agreement)
+            cat_stds = self._get_stats_by_agreement_cat(stats_stds, cat, agreement)
+            stats_means_mean[cat] = cat_means.mean().tolist()
+            stats_means_std[cat] = cat_means.std().tolist() if len(cat_means) > 1 else self.eps
+            stats_stds_mean[cat] = cat_stds.mean().tolist()
+            stats_stds_std[cat] = cat_stds.std().tolist() if len(cat_stds) > 1 else self.eps
+        return (stats_means_mean, stats_means_std, stats_stds_mean, stats_stds_std)
+
+
+    def _update_global_batch_stats(self, stats, agreement):
+        stats_mean = [None] * self.cats
+        stats_std = [None] * self.cats
+        for cat in range(self.cats):
+            cat_stats = self._get_stats_by_agreement_cat(stats, cat, agreement)
+            stats_mean[cat] = cat_stats.mean().tolist()
+            stats_std[cat] = cat_stats.std().tolist() if len(cat_stats) > 1 else self.eps
+        return (stats_mean, stats_std)
+
+    
+    def _update_batch_kl(self, stats_means, stats_stds, batch_stats_means_mean, batch_stats_stds_mean, agreement, mode="kl"):
+        g_mean = [None] * self.cats
+        g_std = [None] * self.cats
+        g = [[] for c in range(self.cats)]
+        for cat in range(self.cats):
+            cat_means = self._get_stats_by_agreement_cat(stats_means, cat, agreement)
+            cat_stds = self._get_stats_by_agreement_cat(stats_stds, cat, agreement)
+            for l_itr in range(len(cat_means)):
+                if mode == "kl":
+                    g[cat].append(gaussian_KL(cat_means[l_itr],
+                                              cat_stds[l_itr],
+                                              batch_stats_means_mean[cat],
+                                              batch_stats_stds_mean[cat]))
+                elif mode == "ovl":
+                    g[cat].append(NormalDist(mu=cat_means[l_itr],
+                                             sigma=max(cat_stds[l_itr], self.eps)).overlap(
+                                                 NormalDist(mu=batch_stats_means_mean[cat],
+                                                            sigma=max(batch_stats_stds_mean[cat], self.eps))))
+                else:
+                    raise Exception("Unknown mode for _update_batch_kl.")
+            g_mean[cat] = torch.tensor(g[cat]).mean().tolist()
+            g_std[cat] = torch.tensor(g[cat]).std().tolist() if len(g[cat]) > 1 else self.eps
+        return (g, g_mean, g_std)
+
+    
+    def _get_stats_by_agreement_cat(self, stats, cat, agreement):
+        if cat == 0:
+            return stats
+        elif cat == 1:
+            if len(stats.size()) == 1:
+                return stats[agreement.nonzero()].reshape(-1)
+            else:
+                return stats[:,agreement.nonzero()].reshape(-1)
+        else:
+            if len(stats.size()) == 1:
+                return stats[(~agreement).nonzero()].reshape(-1)
+            else:
+                return stats[:,(~agreement).nonzero()].reshape(-1)
+
+
+    def get_sample_stats(self, batch, stats, no_cat=False):
+        batch_size = batch['input_ids'].size(0)
+        kls = []
+        ovls = []
+        for l_itr in range(batch_size):
+            d_cat = int(not no_cat) * (2 - int(stats['agreement'][l_itr])) # 0, 1 or 2 (agree or disagree)
+
+            if self.num_labels > 1:
+                rlbl = batch['labels'][l_itr].item()
+            else:
+                rlbl = 0
+                
+            d_losses_mean = stats['losses_means'][l_itr].item()
+            d_losses_std = stats['losses_stds'][l_itr].item()
+            d_probs_mean = stats['probs_means'][l_itr].item()
+            d_probs_std = stats['probs_stds'][l_itr].item()
+            if 'grad_scores' in stats.examples:
+                d_grad_scores = stats['grad_scores'][l_itr].item()
+
+            g_kl_div = gaussian_KL(d_losses_mean,
+                                   d_losses_std,
+                                   self.stats['loss_means_mean'][rlbl][d_cat],
+                                   self.stats['loss_stds_mean'][rlbl][d_cat])
+            ovl = NormalDist(mu=d_losses_mean,
+                             sigma=max(d_losses_std, self.eps)).overlap(
+                                 NormalDist(mu=self.stats['loss_means_mean'][rlbl][d_cat],
+                                            sigma=max(self.stats['loss_stds_mean'][rlbl][d_cat], self.eps)))
+
+            kls.append(g_kl_div)
+            ovls.append(ovl)
+        return kls, ovls