learning.py 42,6 ko
Newer Older
withal's avatar
withal a validé
    totals = defaultdict(int)
    for v, w in zip(values, weights):
        totals[v] += w
    return max(totals, key=totals.__getitem__)
Donato Meoli's avatar
Donato Meoli a validé

def RandomForest(dataset, n=5):
    """An ensemble of Decision Trees trained using bagging and feature bagging.

    Trains n DecisionTreeLearner predictors, each on its own resampled copy of
    dataset.examples and its own random subset of dataset.inputs, and returns
    a predict(example) function that combines their answers with mode().
    """

    def data_bagging(dataset, m=0):
        """Sample m examples with replacement (len(dataset.examples) when m is 0)."""
        n = len(dataset.examples)
        # Equal weights: every example is equally likely to be drawn.
        return weighted_sample_with_replacement(m or n, dataset.examples, [1] * n)

    def feature_bagging(dataset, p=0.7):
        """Feature bagging with probability p to retain an attribute."""
        inputs = [i for i in dataset.inputs if probability(p)]
        # If every attribute happened to be dropped, fall back to all inputs.
        return inputs or dataset.inputs

    def predict(example):
        """Return mode() of the individual predictors' outputs for example."""
        # No debug print here: predicting must not write to stdout, and each
        # predictor should only be evaluated once per example.
        return mode(predictor(example) for predictor in predictors)

    predictors = [DecisionTreeLearner(DataSet(examples=data_bagging(dataset), attrs=dataset.attrs,
                                              attr_names=dataset.attr_names, target=dataset.target,
                                              inputs=feature_bagging(dataset))) for _ in range(n)]

    return predict
def WeightedLearner(unweighted_learner):
    """
    [Page 749 footnote 14]
    Given a learner that takes just an unweighted dataset, return
    one that takes also a weight for each example.
    """

    def train(dataset, weights):
        """Train unweighted_learner on dataset replicated according to weights."""
        return unweighted_learner(replicated_dataset(dataset, weights))

    # Hand the weighted wrapper back to the caller, as the docstring promises.
    return train

def replicated_dataset(dataset, weights, n=None):
    """Return a shallow copy of dataset whose examples are replicated by weight.

    n is the number of examples in the copy; a falsy n (None or 0) means
    "as many as the original dataset has".
    """
    if not n:
        n = len(dataset.examples)
    clone = copy.copy(dataset)
    # Only the examples attribute is replaced on the copy.
    clone.examples = weighted_replicate(dataset.examples, weights, n)
    return clone

def weighted_replicate(seq, weights, n):
    """
    Return n selections from seq, with the count of each element of
    seq proportional to the corresponding weight (filling in fractions
    randomly).
    >>> weighted_replicate('ABC', [1, 2, 1], 4)
    ['A', 'B', 'B', 'C']
    """
    assert len(seq) == len(weights)
    weights = normalize(weights)
    # Whole copies each element is guaranteed to receive.
    wholes = [int(w * n) for w in weights]
    # Fractional remainders become the sampling weights for the leftover slots.
    fractions = [(w * n) % 1 for w in weights]
    return (flatten([x] * nx for x, nx in zip(seq, wholes)) +
            weighted_sample_with_replacement(n - sum(wholes), seq, fractions))
# Named datasets used as defaults elsewhere in this module.
# NOTE(review): no examples are passed here, so DataSet presumably loads them
# by name -- confirm against the DataSet constructor.
orings = DataSet(
    name='orings',
    target='Distressed',
    attr_names='Rings Distressed Temp Pressure Flightnum',
)

# The 'name' column is excluded from the inputs.
zoo = DataSet(
    name='zoo',
    target='type',
    exclude=['name'],
    attr_names=(
        'name hair feathers eggs milk airborne aquatic predator toothed backbone '
        'breathes venomous fins legs tail domestic catsize type'
    ),
)

iris = DataSet(
    name='iris',
    target='class',
    attr_names='sepal-len sepal-width petal-len petal-width class',
)
def RestaurantDataSet(examples=None):
    """
    [Figure 18.3]
    Return the 'restaurant' DataSet (target attribute 'Wait'), optionally
    built from the given examples.
    """
    attribute_names = ('Alternate Bar Fri/Sat Hungry Patrons Price '
                       'Raining Reservation Type WaitEstimate Wait')
    return DataSet(name='restaurant', target='Wait', examples=examples,
                   attr_names=attribute_names)
def T(attr_name, branches):
    """Build a DecisionFork testing attr_name of the restaurant dataset.

    branches maps attribute values to children; any child that is not already
    a DecisionFork is wrapped in a DecisionLeaf.
    """
    wrapped = {}
    for value, child in branches.items():
        if not isinstance(child, DecisionFork):
            child = DecisionLeaf(child)
        wrapped[value] = child
    attr_index = restaurant.attr_num(attr_name)
    # NOTE(review): `print` is passed as the third positional argument,
    # presumably the fork's default child -- confirm against DecisionFork.
    return DecisionFork(attr_index, attr_name, print, wrapped)
"""
A decision tree for deciding whether to wait for a table at a restaurant.
"""

# Hand-built tree: each T(...) tests one attribute, and plain string values
# are the answers that T wraps into leaves.
waiting_decision_tree = T(
    'Patrons',
    {
        'None': 'No',
        'Some': 'Yes',
        'Full': T(
            'WaitEstimate',
            {
                '>60': 'No',
                '0-10': 'Yes',
                '30-60': T(
                    'Alternate',
                    {
                        'No': T(
                            'Reservation',
                            {
                                'Yes': 'Yes',
                                'No': T('Bar', {'No': 'No', 'Yes': 'Yes'}),
                            },
                        ),
                        'Yes': T('Fri/Sat', {'No': 'No', 'Yes': 'Yes'}),
                    },
                ),
                '10-30': T(
                    'Hungry',
                    {
                        'No': 'Yes',
                        'Yes': T(
                            'Alternate',
                            {
                                'No': 'Yes',
                                'Yes': T('Raining', {'No': 'No', 'Yes': 'Yes'}),
                            },
                        ),
                    },
                ),
            },
        ),
    },
)
def SyntheticRestaurant(n=20):
    """Generate a DataSet with n examples."""

    def gen():
        """Build one random example labelled by waiting_decision_tree."""
        # Pick a random value for every attribute, then overwrite the target
        # slot with the tree's decision for that example.
        example = list(map(random.choice, restaurant.values))
        example[restaurant.target] = waiting_decision_tree(example)
        return example

    return RestaurantDataSet([gen() for _ in range(n)])


def Majority(k, n):
    """
    Return a DataSet with n k-bit examples of the majority problem:
    k random bits followed by a 1 if more than half the bits are 1, else 0.
    """
    examples = []
    for _ in range(n):
        bits = [random.choice([0, 1]) for _ in range(k)]
        # Label: 1 when strictly more than half of the k bits are set.
        bits.append(int(sum(bits) > k / 2))
        examples.append(bits)
    return DataSet(name='majority', examples=examples)
def Parity(k, n, name='parity'):
    """
    Return a DataSet with n k-bit examples of the parity problem:
    k random bits followed by a 1 if an odd number of bits are 1, else 0.
    """
    examples = []
    for _ in range(n):
        bits = [random.choice([0, 1]) for _ in range(k)]
        # Label: the sum of the bits modulo 2, i.e. 1 for an odd count of ones.
        bits.append(sum(bits) % 2)
        examples.append(bits)
    return DataSet(name=name, examples=examples)

def Xor(n):
    """Return a DataSet with n examples of 2-input xor."""
    # 2-input xor is the 2-bit parity problem under a different dataset name.
    return Parity(2, n, name='xor')


def ContinuousXor(n):
    """2 inputs are chosen uniformly between 0.0 and 2.0; output is xor of ints."""
    examples = []
    for _ in range(n):
        x, y = [random.uniform(0.0, 2.0) for _ in '12']
        # Compare the integer parts: comparing the raw floats would be True
        # for practically every pair and would not be an xor at all.
        examples.append([x, y, int(x) != int(y)])
    return DataSet(name='continuous xor', examples=examples)
Donato Meoli's avatar
Donato Meoli a validé

def compare(algorithms=None, datasets=None, k=10, trials=1):
    """
    Run cross_validation for every (learner, dataset) pair and print the
    results as a table: one row per learner, one column per dataset.
    Falsy algorithms / datasets arguments select the default lists below.
    """
    if not algorithms:
        algorithms = [PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, DecisionTreeLearner]
    if not datasets:
        datasets = [iris, orings, zoo, restaurant, SyntheticRestaurant(20),
                    Majority(7, 100), Parity(7, 100), Xor(100)]

    rows = []
    for learner in algorithms:
        # Row label is the learner's name without the 'Learner' suffix.
        label = learner.__name__.replace('Learner', '')
        scores = [cross_validation(learner, data, k=k, trials=trials) for data in datasets]
        rows.append([label] + scores)

    # Column titles are the dataset names cut to seven characters.
    column_titles = [''] + [data.name[0:7] for data in datasets]
    print_table(rows, header=column_titles, numfmt='%.2f')