    for example in examples:
        desired = example[dataset.target]
        output = predict(dataset.sanitize(example))
        if output == desired:
            right += 1
            if verbose >= 2:
                print('   OK: got {} for {}'.format(desired, example))
            print('WRONG: got {}, expected {} for {}'.format(
MircoT's avatar
MircoT a validé
                output, desired, example))
    return 1 - (right/len(examples))
def grade_learner(predict, tests):
    """Return the fraction of tests the learner gets right.

    predict: callable mapping an input to a predicted output.
    tests: list of (values, output) pairs; a test passes when
        predict(values) == output.
    """
    passed = [int(predict(values) == expected) for values, expected in tests]
    return mean(passed)


def train_test_split(dataset, start, end):
    """Reserve dataset.examples[start:end] for test; train on the remainder.

    start and end are truncated with int(), so fractional fold boundaries
    such as fold * (n / k) can be passed directly.

    Returns (train, val) as new lists; dataset.examples itself is untouched.
    """
    start, end = int(start), int(end)
    examples = dataset.examples
    train = examples[:start] + examples[end:]
    val = examples[start:end]
    return train, val
def cross_validation(learner, size, dataset, k=10, trials=1):
    """Do k-fold cross-validation and return the mean training and validation errors.

    That is, keep out 1/k of the examples for testing on each of k runs.
    The examples are shuffled in place once per trial and then cut into k
    disjoint, contiguous folds. If trials > 1, the whole procedure is
    repeated `trials` times with a fresh shuffle each time and the errors
    are averaged.

    learner(dataset, size) is called with dataset.examples temporarily set
    to the training folds; it must return a predictor usable by err_ratio.
    A falsy k (0 or None) means leave-one-out: k = len(dataset.examples).

    Returns (training error, validation error).
    """
    k = k or len(dataset.examples)
    if trials > 1:
        trial_errT = 0
        trial_errV = 0
        for _ in range(trials):
            # Propagate k so every trial uses the requested number of folds.
            errT, errV = cross_validation(learner, size, dataset,
                                          k=k, trials=1)
            trial_errT += errT
            trial_errV += errV
        return trial_errT / trials, trial_errV / trials

    examples = dataset.examples
    n = len(examples)
    # Shuffle once, before splitting: reshuffling inside the fold loop would
    # make the k validation folds overlap instead of partitioning the data.
    random.shuffle(examples)
    fold_errT = 0
    fold_errV = 0
    try:
        for fold in range(k):
            # train_test_split slices dataset.examples, so it must see the
            # full list on every fold.
            dataset.examples = examples
            # Integer arithmetic gives exact fold boundaries (no float rounding).
            train_data, val_data = train_test_split(dataset, fold * n // k,
                                                    (fold + 1) * n // k)
            # The learner trains on dataset.examples: expose only the training part.
            dataset.examples = train_data
            h = learner(dataset, size)
            fold_errT += err_ratio(h, dataset, train_data)
            fold_errV += err_ratio(h, dataset, val_data)
    finally:
        # Revert to the original list once testing is completed,
        # even if the learner or err_ratio raises.
        dataset.examples = examples
    return fold_errT / k, fold_errV / k
def cross_validation_wrapper(learner, dataset, k=10, trials=1):
    """[Fig 18.8]
    Return the optimal value of size having minimum error on the validation set.

    Sizes 1, 2, 3, ... are tried in turn with cross_validation(learner, size,
    dataset, k, trials). Once the training error has converged (two
    consecutive sizes give values equal within rel_tol=1e-6), the size with
    the lowest validation error among all sizes tried is returned. The
    smallest such size wins on ties.

    NOTE(review): if the training error never converges, this loops forever.

    err_train: A training error array, indexed by size - 1
    err_val: A validation error array, indexed by size - 1
    """
    err_val = []
    err_train = []
    size = 1
    while True:
        errT, errV = cross_validation(learner, size, dataset, k, trials)
        err_train.append(errT)
        err_val.append(errV)
        # Converged once the training error stops changing between sizes.
        if len(err_train) > 1 and isclose(err_train[-2], errT, rel_tol=1e-6):
            # err_val[i] was measured with size i + 1 (sizes start at 1);
            # min() returns the first index on ties.
            best_index = min(range(len(err_val)), key=err_val.__getitem__)
            return best_index + 1
        size += 1


def leave_one_out(learner, dataset, size=None):
    """Leave-one-out cross-validation over the dataset.

    Runs cross_validation with one fold per example, so each example is held
    out exactly once; returns cross_validation's result unchanged.
    """
    n_examples = len(dataset.examples)
    return cross_validation(learner, size, dataset, k=n_examples)
def learningcurve(learner, dataset, trials=10, sizes=None):
    """Estimate a learning curve for `learner` on `dataset`.

    For each training-set size in `sizes`:
    - shuffle dataset.examples in place;
    - train on the first `size` examples;
    - measure accuracy (1 - err_ratio) on the remaining examples.
    This is repeated `trials` times and averaged.

    sizes defaults to 2, 4, ... below len(dataset.examples) - 10, so at least
    10 examples are held out for testing.

    NOTE(review): learner is called as learner(dataset) (no size argument);
    confirm this matches the learners passed in.

    Returns a list of (size, mean accuracy) pairs.
    """
    if sizes is None:
        sizes = list(range(2, len(dataset.examples) - 10, 2))

    def score(learner, size):
        """Shuffle, train on the first `size` examples, return accuracy on the rest."""
        random.shuffle(dataset.examples)
        examples = dataset.examples
        # Test slice is examples[size:], so training is exactly examples[:size].
        train, val = train_test_split(dataset, size, len(examples))
        try:
            # The learner trains on dataset.examples: expose only the training part.
            dataset.examples = train
            h = learner(dataset)
            return 1 - err_ratio(h, dataset, val)
        finally:
            dataset.examples = examples

    return [(size, mean([score(learner, size) for _ in range(trials)]))
            for size in sizes]

# ______________________________________________________________________________
# The rest of this file gives datasets for machine learning problems.
orings = DataSet(name='orings', target='Distressed',
                 attrnames="Rings Distressed Temp Pressure Flightnum")


# 'name' is excluded from the inputs: it identifies the animal rather than describing it.
zoo = DataSet(name='zoo', target='type', exclude=['name'],
              attrnames="name hair feathers eggs milk airborne aquatic " +
              "predator toothed backbone breathes venomous fins legs tail " +
              "domestic catsize type")


iris = DataSet(name="iris", target="class",
               attrnames="sepal-len sepal-width petal-len petal-width class")

# ______________________________________________________________________________
# The Restaurant example from [Figure 18.2]
def RestaurantDataSet(examples=None):
    """Build a DataSet of Restaurant waiting examples. [Figure 18.3]

    examples is passed straight through to DataSet; the target attribute is
    'Wait'.
    """
    return DataSet(name='restaurant', target='Wait', examples=examples,
                   attrnames='Alternate Bar Fri/Sat Hungry Patrons Price ' +
                   'Raining Reservation Type WaitEstimate Wait')


# Shared restaurant dataset; T() and SyntheticRestaurant() look up
# attribute indices / values on it.
restaurant = RestaurantDataSet()


def T(attrname, branches):
    """Build a DecisionFork that tests restaurant attribute `attrname`.

    branches maps each attribute value to a child. A child that is already a
    DecisionFork is used as-is; any other value is wrapped in a DecisionLeaf.
    NOTE(review): `print` is passed as DecisionFork's third argument
    (presumably the handler for unseen values) — confirm against DecisionFork.
    """
    branches = {value: (child if isinstance(child, DecisionFork)
                        else DecisionLeaf(child))
                for value, child in branches.items()}
    return DecisionFork(restaurant.attrnum(attrname), attrname, print, branches)
""" [Figure 18.2]
A decision tree for deciding whether to wait for a table at a hotel.
"""

waiting_decision_tree = T('Patrons',
                          {'None': 'No', 'Some': 'Yes',
                           'Full': T('WaitEstimate',
                                     {'>60': 'No', '0-10': 'Yes',
                                      '30-60': T('Alternate',
                                                 {'No': T('Reservation',
                                                          {'Yes': 'Yes',
                                                           'No': T('Bar', {'No': 'No',
                                                                           'Yes': 'Yes'})}),
                                                  'Yes': T('Fri/Sat', {'No': 'No', 'Yes': 'Yes'})}
                                                 ),
                                      '10-30': T('Hungry',
                                                 {'No': 'Yes',
                                                  'Yes': T('Alternate',
                                                           {'No': 'Yes',
                                                            'Yes': T('Raining',
                                                                     {'No': 'No',
                                                                      'Yes': 'Yes'})})})})})
def SyntheticRestaurant(n=20):
    """Generate a DataSet with n examples.

    Each example picks every attribute at random (random.choice) from the
    corresponding entry of restaurant.values. The target column is then
    overwritten with the decision of waiting_decision_tree, so the labels
    follow Figure 18.2.
    """
    def gen():
        """Return one random, tree-labelled restaurant example."""
        example = list(map(random.choice, restaurant.values))
        example[restaurant.target] = waiting_decision_tree(example)
        return example
    return RestaurantDataSet([gen() for _ in range(n)])

# ______________________________________________________________________________
def Majority(k, n):
    """Return a DataSet with n k-bit examples of the majority problem.

    Each example is k random bits followed by a 1 if more than half the bits
    are 1, else 0."""
    def draw():
        bits = [random.choice([0, 1]) for _ in range(k)]
        label = int(sum(bits) > k / 2)
        return bits + [label]

    examples = [draw() for _ in range(n)]
    return DataSet(name="majority", examples=examples)

def Parity(k, n, name="parity"):
    """Return a DataSet with n k-bit examples of the parity problem:
    k random bits followed by a 1 if an odd number of bits are 1, else 0."""
    examples = []
    for i in range(n):
        bits = [random.choice([0, 1]) for i in range(k)]
        bits.append(sum(bits) % 2)
        examples.append(bits)
    return DataSet(name=name, examples=examples)

def Xor(n):
    """Return a DataSet with n examples of 2-input xor (i.e. 2-bit parity)."""
    return Parity(k=2, n=n, name="xor")

def ContinuousXor(n):
    """Return a DataSet with n examples of continuous 2-input xor.

    2 inputs are chosen with random.uniform(0.0, 2.0) (range [0.0, 2.0]);
    the output is xor of their integer parts, stored as the boolean
    int(x) != int(y).
    """
    examples = []
    for _ in range(n):
        x, y = [random.uniform(0.0, 2.0) for _ in '12']
        examples.append([x, y, int(x) != int(y)])
    return DataSet(name="continuous xor", examples=examples)

# ______________________________________________________________________________
def compare(algorithms=None,
            datasets=None,
            k=10, trials=1):
    """Compare various learners on various datasets using cross-validation.
    Print results as a table.

    Each cell is the mean validation error reported by cross_validation for
    that (algorithm, dataset) pair.
    NOTE(review): algorithms are assumed to be called as learner(dataset)
    (no size argument), like the default list — confirm for custom learners.
    """
    algorithms = algorithms or [PluralityLearner, NaiveBayesLearner,                 # default list
                                NearestNeighborLearner, DecisionTreeLearner]         # of algorithms

    datasets = datasets or [iris, orings, zoo, restaurant, SyntheticRestaurant(20),  # default list
                            Majority(7, 100), Parity(7, 100), Xor(100)]              # of datasets

    def ignore_size(learner):
        """Adapt learner(dataset) to the learner(dataset, size) call cross_validation makes."""
        return lambda dataset, size: learner(dataset)

    # cross_validation returns (training error, validation error); report the latter.
    print_table([[a.__name__.replace('Learner', '')] +
                 [cross_validation(ignore_size(a), None, d, k, trials)[1]
                  for d in datasets]
                 for a in algorithms],
                header=[''] + [d.name[0:7] for d in datasets], numfmt='%.2f')